This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e49665ac39b [SPARK-47960][SS] Allow chaining other stateful operators 
after transformWithState operator
5e49665ac39b is described below

commit 5e49665ac39b49b875d6970f93df59aedd830fa5
Author: Bhuwan Sahni <bhuwan.sa...@databricks.com>
AuthorDate: Wed May 8 09:20:01 2024 +0900

    [SPARK-47960][SS] Allow chaining other stateful operators after 
transformWithState operator
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for defining the event time column in the output 
dataset of the `TransformWithState` operator. The new event time column will be 
used to evaluate watermark expressions in downstream operators.
    
    1. Note that the transformWithState operator itself does not enforce that 
values generated by the user's computation adhere to watermark semantics (i.e. 
that no output rows are generated with an event time less than the watermark).
    2. Updated the watermark value passed in TimerInfo as evictionWatermark, 
rather than lateEventsWatermark.
    3. Ensure that the event time column can only be defined in the output if a 
watermark has been defined previously.
    
    ### Why are the changes needed?
    
    This change is required to support chaining of stateful operators after 
`transformWithState`. Event time column is required to evaluate watermark 
expressions in downstream stateful operators.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Adds a new version of transformWithState API which allows redefining 
the event time column.
    
    ### How was this patch tested?
    
    Added unit test cases.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45376 from sahnib/tws-chaining-stateful-operators.
    
    Authored-by: Bhuwan Sahni <bhuwan.sa...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../src/main/resources/error/error-conditions.json |  14 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   3 +
 .../ResolveUpdateEventTimeWatermarkColumn.scala    |  52 +++
 .../plans/logical/EventTimeWatermark.scala         |  79 +++-
 .../spark/sql/catalyst/plans/logical/object.scala  |   2 +-
 .../sql/catalyst/rules/RuleIdCollection.scala      |   1 +
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   1 +
 .../spark/sql/errors/QueryCompilationErrors.scala  |   7 +
 .../spark/sql/errors/QueryExecutionErrors.scala    |  11 +
 .../apache/spark/sql/KeyValueGroupedDataset.scala  | 178 ++++++++-
 .../spark/sql/execution/SparkStrategies.scala      |  12 +
 .../streaming/EventTimeWatermarkExec.scala         |  88 ++++-
 .../execution/streaming/IncrementalExecution.scala |  24 +-
 .../streaming/TransformWithStateExec.scala         |  32 +-
 .../execution/streaming/statefulOperators.scala    |   2 +-
 .../TransformWithStateChainingSuite.scala          | 411 +++++++++++++++++++++
 16 files changed, 866 insertions(+), 51 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index bae94a0ab97e..8a64c4c590e8 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -125,6 +125,12 @@
     ],
     "sqlState" : "428FR"
   },
+  "CANNOT_ASSIGN_EVENT_TIME_COLUMN_WITHOUT_WATERMARK" : {
+    "message" : [
+      "Watermark needs to be defined to reassign event time column. Failed to 
find watermark definition in the streaming query."
+    ],
+    "sqlState" : "42611"
+  },
   "CANNOT_CAST_DATATYPE" : {
     "message" : [
       "Cannot cast <sourceType> to <targetType>."
@@ -1057,6 +1063,14 @@
     },
     "sqlState" : "4274K"
   },
+  "EMITTING_ROWS_OLDER_THAN_WATERMARK_NOT_ALLOWED" : {
+    "message" : [
+      "Previous node emitted a row with eventTime=<emittedRowEventTime> which 
is older than current_watermark_value=<currentWatermark>",
+      "This can lead to correctness issues in the stateful operators 
downstream in the execution pipeline.",
+      "Please correct the operator logic to emit rows after current global 
watermark value."
+    ],
+    "sqlState" : "42815"
+  },
   "EMPTY_JSON_FIELD_VALUE" : {
     "message" : [
       "Failed to parse an empty string for data type <dataType>."
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c29432c916f9..55b6f1af7fd8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -331,6 +331,7 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
       Seq(
         ResolveWithCTE,
         ExtractDistributedSequenceID) ++
+      Seq(ResolveUpdateEventTimeWatermarkColumn) ++
       extendedResolutionRules : _*),
     Batch("Remove TempResolvedColumn", Once, RemoveTempResolvedColumn),
     Batch("Post-Hoc Resolution", Once,
@@ -4003,6 +4004,8 @@ object EliminateEventTimeWatermark extends 
Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsWithPruning(
     _.containsPattern(EVENT_TIME_WATERMARK)) {
     case EventTimeWatermark(_, _, child) if child.resolved && 
!child.isStreaming => child
+    case UpdateEventTimeWatermarkColumn(_, _, child) if child.resolved && 
!child.isStreaming =>
+      child
   }
 }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala
new file mode 100644
index 000000000000..31c4f068a83e
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, 
LogicalPlan, UpdateEventTimeWatermarkColumn}
+import org.apache.spark.sql.catalyst.rules.Rule
+import 
org.apache.spark.sql.catalyst.trees.TreePattern.UPDATE_EVENT_TIME_WATERMARK_COLUMN
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+/**
+ * Extracts the watermark delay and adds it to the 
UpdateEventTimeWatermarkColumn
+ * logical node (if such a node is present). 
[[UpdateEventTimeWatermarkColumn]] node updates
+ * the eventTimeColumn for downstream operators.
+ *
+ * If the logical plan contains a [[UpdateEventTimeWatermarkColumn]] node, but 
no watermark
+ * has been defined, the query will fail with a compilation error.
+ */
+object ResolveUpdateEventTimeWatermarkColumn extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsUpWithPruning(
+    _.containsPattern(UPDATE_EVENT_TIME_WATERMARK_COLUMN), ruleId) {
+    case u: UpdateEventTimeWatermarkColumn if u.delay.isEmpty && 
u.childrenResolved =>
+      val existingWatermarkDelay = u.child.collect {
+        case EventTimeWatermark(_, delay, _) => delay
+      }
+
+      if (existingWatermarkDelay.isEmpty) {
+        // input dataset needs to have an event time column; we transfer the
+        // watermark delay from this column to user specified 
eventTimeColumnName
+        // in the output dataset.
+        throw QueryCompilationErrors.cannotAssignEventTimeColumn()
+      }
+
+      val delay = existingWatermarkDelay.head
+      u.copy(delay = Some(delay))
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
index 32a9030ff62b..8cfc939755ef 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.plans.logical
 import java.util.concurrent.TimeUnit
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.trees.TreePattern.{EVENT_TIME_WATERMARK, 
TreePattern}
+import 
org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark.updateEventTimeColumn
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EVENT_TIME_WATERMARK, 
TreePattern, UPDATE_EVENT_TIME_WATERMARK_COLUMN}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.types.MetadataBuilder
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -32,6 +33,36 @@ object EventTimeWatermark {
   def getDelayMs(delay: CalendarInterval): Long = {
     IntervalUtils.getDuration(delay, TimeUnit.MILLISECONDS)
   }
+
+  /**
+   * Adds watermark delay to the metadata for newEventTime in provided 
attributes.
+   *
+   * If any other existing attributes have watermark delay present in their 
metadata, watermark
+   * delay will be removed from their metadata.
+   */
+  def updateEventTimeColumn(
+      attributes: Seq[Attribute],
+      delayMs: Long,
+      newEventTime: Attribute): Seq[Attribute] = {
+    attributes.map { a =>
+      if (a semanticEquals newEventTime) {
+        val updatedMetadata = new MetadataBuilder()
+          .withMetadata(a.metadata)
+          .putLong(EventTimeWatermark.delayKey, delayMs)
+          .build()
+        a.withMetadata(updatedMetadata)
+      } else if (a.metadata.contains(EventTimeWatermark.delayKey)) {
+        // Remove existing columns tagged as eventTime for watermark
+        val updatedMetadata = new MetadataBuilder()
+          .withMetadata(a.metadata)
+          .remove(EventTimeWatermark.delayKey)
+          .build()
+        a.withMetadata(updatedMetadata)
+      } else {
+        a
+      }
+    }
+  }
 }
 
 /**
@@ -49,26 +80,38 @@ case class EventTimeWatermark(
   // logic here because we also maintain the compatibility flag. (See
   // SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE for details.)
   // TODO: Disallow updating the metadata once we remove the compatibility 
flag.
-  override val output: Seq[Attribute] = child.output.map { a =>
-    if (a semanticEquals eventTime) {
-      val delayMs = EventTimeWatermark.getDelayMs(delay)
-      val updatedMetadata = new MetadataBuilder()
-        .withMetadata(a.metadata)
-        .putLong(EventTimeWatermark.delayKey, delayMs)
-        .build()
-      a.withMetadata(updatedMetadata)
-    } else if (a.metadata.contains(EventTimeWatermark.delayKey)) {
-      // Remove existing watermark
-      val updatedMetadata = new MetadataBuilder()
-        .withMetadata(a.metadata)
-        .remove(EventTimeWatermark.delayKey)
-        .build()
-      a.withMetadata(updatedMetadata)
+  override val output: Seq[Attribute] = {
+    val delayMs = EventTimeWatermark.getDelayMs(delay)
+    updateEventTimeColumn(child.output, delayMs, eventTime)
+  }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): 
EventTimeWatermark =
+    copy(child = newChild)
+}
+
+/**
+ * Updates the event time column to [[eventTime]] in the child output.
+ *
+ * Any watermark calculations performed after this node will use the
+ * updated eventTimeColumn.
+ */
+case class UpdateEventTimeWatermarkColumn(
+    eventTime: Attribute,
+    delay: Option[CalendarInterval],
+    child: LogicalPlan) extends UnaryNode {
+
+  final override val nodePatterns: Seq[TreePattern] = 
Seq(UPDATE_EVENT_TIME_WATERMARK_COLUMN)
+
+  override def output: Seq[Attribute] = {
+    if (delay.isDefined) {
+      val delayMs = EventTimeWatermark.getDelayMs(delay.get)
+      updateEventTimeColumn(child.output, delayMs, eventTime)
     } else {
-      a
+      child.output
     }
   }
 
-  override protected def withNewChildInternal(newChild: LogicalPlan): 
EventTimeWatermark =
+  override protected def withNewChildInternal(
+      newChild: LogicalPlan): UpdateEventTimeWatermarkColumn =
     copy(child = newChild)
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 28d52d39093b..07423b612c30 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -579,7 +579,7 @@ object TransformWithState {
       child: LogicalPlan): LogicalPlan = {
     val keyEncoder = encoderFor[K]
     val mapped = new TransformWithState(
-      UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+      UnresolvedDeserializer(keyEncoder.deserializer, groupingAttributes),
       UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
       groupingAttributes,
       dataAttributes,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 11df764ebb03..c862c4043a2a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -104,6 +104,7 @@ object RuleIdCollection {
       
"org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CombinedTypeCoercionRule"
 ::
       "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
       "org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability" ::
+      
"org.apache.spark.sql.catalyst.analysis.ResolveUpdateEventTimeWatermarkColumn" 
::
       // Catalyst Optimizer rules
       "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 4ab075db5709..9c0038052893 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -133,6 +133,7 @@ object TreePattern extends Enumeration  {
   val UNION: Value = Value
   val UNRESOLVED_RELATION: Value = Value
   val UNRESOLVED_WITH: Value = Value
+  val UPDATE_EVENT_TIME_WATERMARK_COLUMN: Value = Value
   val TEMP_RESOLVED_COLUMN: Value = Value
   val TYPED_FILTER: Value = Value
   val WINDOW: Value = Value
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 24c85b76437c..95101f1fd19f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -4061,4 +4061,11 @@ private[sql] object QueryCompilationErrors extends 
QueryErrorsBase with Compilat
     callDeprecatedMethodError("createTable(..., StructType, ...)",
       "createTable(..., Array[Column], ...)")
   }
+
+  def cannotAssignEventTimeColumn(): Throwable = {
+    new AnalysisException(
+      errorClass = "CANNOT_ASSIGN_EVENT_TIME_COLUMN_WITHOUT_WATERMARK",
+      messageParameters = Map()
+    )
+  }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 53ac788956d1..1f3283ebed05 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2762,4 +2762,15 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase with ExecutionE
         "numFields" -> numFields.toString,
         "schemaLen" -> schemaLen.toString))
   }
+
+  def emittedRowsAreOlderThanWatermark(
+      currentWatermark: Long, emittedRowEventTime: Long): 
SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "EMITTING_ROWS_OLDER_THAN_WATERMARK_NOT_ALLOWED",
+      messageParameters = Map(
+        "currentWatermark" -> currentWatermark.toString,
+        "emittedRowEventTime" -> emittedRowEventTime.toString
+      )
+    )
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 862268eba666..52ab633cd75a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.api.java.function._
+import org.apache.spark.sql.catalyst.analysis.{EliminateEventTimeWatermark, 
UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, 
Expression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -676,6 +677,44 @@ class KeyValueGroupedDataset[K, V] private[sql](
     )
   }
 
+  /**
+   * (Scala-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state 
API v2.
+   * We allow the user to act on per-group set of input rows along with keyed 
state and the
+   * user can choose to output/return 0 or more rows.
+   * For a streaming dataframe, we will repeatedly invoke the interface 
methods for new rows
+   * in each trigger and the user's state/state variables will be stored 
persistently across
+   * invocations.
+   *
+   * Downstream operators would use specified eventTimeColumnName to calculate 
watermark.
+   * Note that TimeMode is set to EventTime to ensure correct flow of 
watermark.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL 
types.
+   * @param statefulProcessor   Instance of statefulProcessor whose functions 
will
+   *                            be invoked by the operator.
+   * @param eventTimeColumnName eventTime column in the output dataset. Any 
operations after
+   *                            transformWithState will use the new 
eventTimeColumn. The user
+   *                            needs to ensure that the eventTime for emitted 
output adheres to
+   *                            the watermark boundary, otherwise streaming 
query will fail.
+   * @param outputMode          The output mode of the stateful processor.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      eventTimeColumnName: String,
+      outputMode: OutputMode): Dataset[U] = {
+    val transformWithState = TransformWithState[K, V, U](
+      groupingAttributes,
+      dataAttributes,
+      statefulProcessor,
+      TimeMode.EventTime(),
+      outputMode,
+      child = logicalPlan
+    )
+    updateEventTimeColumnAfterTransformWithState(transformWithState, 
eventTimeColumnName)
+  }
+
   /**
    * (Java-specific)
    * Invokes methods defined in the stateful processor used in arbitrary state 
API v2.
@@ -702,6 +741,39 @@ class KeyValueGroupedDataset[K, V] private[sql](
     transformWithState(statefulProcessor, timeMode, outputMode)(outputEncoder)
   }
 
+  /**
+   * (Java-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state 
API v2.
+   * We allow the user to act on per-group set of input rows along with keyed 
state and the
+   * user can choose to output/return 0 or more rows.
+   *
+   * For a streaming dataframe, we will repeatedly invoke the interface 
methods for new rows
+   * in each trigger and the user's state/state variables will be stored 
persistently across
+   * invocations.
+   *
+   * Downstream operators would use specified eventTimeColumnName to calculate 
watermark.
+   * Note that TimeMode is set to EventTime to ensure correct flow of 
watermark.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL 
types.
+   * @param statefulProcessor Instance of statefulProcessor whose functions 
will be invoked by the
+   *                          operator.
+   * @param eventTimeColumnName eventTime column in the output dataset. Any 
operations after
+   *                            transformWithState will use the new 
eventTimeColumn. The user
+   *                            needs to ensure that the eventTime for emitted 
output adheres to
+   *                            the watermark boundary, otherwise streaming 
query will fail.
+   * @param outputMode        The output mode of the stateful processor.
+   * @param outputEncoder     Encoder for the output type.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      eventTimeColumnName: String,
+      outputMode: OutputMode,
+      outputEncoder: Encoder[U]): Dataset[U] = {
+    transformWithState(statefulProcessor, eventTimeColumnName, 
outputMode)(outputEncoder)
+  }
+
   /**
    * (Scala-specific)
    * Invokes methods defined in the stateful processor used in arbitrary state 
API v2.
@@ -739,19 +811,98 @@ class KeyValueGroupedDataset[K, V] private[sql](
     )
   }
 
+  /**
+   * (Scala-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state 
API v2.
+   * Functions as the function above, but with additional eventTimeColumnName 
for output.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL 
types.
+   * @tparam S The type of initial state objects. Must be encodable to Spark 
SQL types.
+   *
+   * Downstream operators would use specified eventTimeColumnName to calculate 
watermark.
+   * Note that TimeMode is set to EventTime to ensure correct flow of 
watermark.
+   *
+   * @param statefulProcessor   Instance of statefulProcessor whose functions 
will
+   *                            be invoked by the operator.
+   * @param eventTimeColumnName eventTime column in the output dataset. Any 
operations after
+   *                            transformWithState will use the new 
eventTimeColumn. The user
+   *                            needs to ensure that the eventTime for emitted 
output adheres to
+   *                            the watermark boundary, otherwise streaming 
query will fail.
+   * @param outputMode          The output mode of the stateful processor.
+   * @param initialState        User provided initial state that will be used 
to initiate state for
+   *                            the query in the first batch.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      eventTimeColumnName: String,
+      outputMode: OutputMode,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
+    val transformWithState = TransformWithState[K, V, U, S](
+      groupingAttributes,
+      dataAttributes,
+      statefulProcessor,
+      TimeMode.EventTime(),
+      outputMode,
+      child = logicalPlan,
+      initialState.groupingAttributes,
+      initialState.dataAttributes,
+      initialState.queryExecution.analyzed
+    )
+
+    updateEventTimeColumnAfterTransformWithState(transformWithState, 
eventTimeColumnName)
+  }
+
   /**
    * (Java-specific)
    * Invokes methods defined in the stateful processor used in arbitrary state 
API v2.
-   * Functions as the function above, but with additional initial state.
+   * Functions as the function above, but with additional initialStateEncoder 
for state encoding.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL 
types.
+   * @tparam S The type of initial state objects. Must be encodable to Spark 
SQL types.
+   * @param statefulProcessor   Instance of statefulProcessor whose functions 
will
+   *                            be invoked by the operator.
+   * @param timeMode            The time mode semantics of the stateful 
processor for
+   *                            timers and TTL.
+   * @param outputMode          The output mode of the stateful processor.
+   * @param initialState        User provided initial state that will be used 
to initiate state for
+   *                            the query in the first batch.
+   * @param outputEncoder       Encoder for the output type.
+   * @param initialStateEncoder Encoder for the initial state type.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   */
+  private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      initialState: KeyValueGroupedDataset[K, S],
+      outputEncoder: Encoder[U],
+      initialStateEncoder: Encoder[S]): Dataset[U] = {
+    transformWithState(statefulProcessor, timeMode,
+      outputMode, initialState)(outputEncoder, initialStateEncoder)
+  }
+
+  /**
+   * (Java-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state 
API v2.
+   * Functions as the function above, but with additional eventTimeColumnName 
for output.
+   *
+   * Downstream operators would use specified eventTimeColumnName to calculate 
watermark.
+   * Note that TimeMode is set to EventTime to ensure correct flow of 
watermark.
    *
    * @tparam U The type of the output objects. Must be encodable to Spark SQL 
types.
    * @tparam S The type of initial state objects. Must be encodable to Spark 
SQL types.
    * @param statefulProcessor Instance of statefulProcessor whose functions 
will
    *                          be invoked by the operator.
-   * @param timeMode          The time mode semantics of the stateful 
processor for timers and TTL.
    * @param outputMode        The output mode of the stateful processor.
    * @param initialState      User provided initial state that will be used to 
initiate state for
    *                          the query in the first batch.
+   * @param eventTimeColumnName eventTime column in the output dataset. Any 
operations after
+   *                            transformWithState will use the new 
eventTimeColumn. The user
+   *                            needs to ensure that the eventTime for emitted 
output adheres to
+   *                            the watermark boundary, otherwise streaming 
query will fail.
    * @param outputEncoder     Encoder for the output type.
    * @param initialStateEncoder Encoder for the initial state type.
    *
@@ -759,15 +910,34 @@ class KeyValueGroupedDataset[K, V] private[sql](
    */
   private[sql] def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
-      timeMode: TimeMode,
       outputMode: OutputMode,
       initialState: KeyValueGroupedDataset[K, S],
+      eventTimeColumnName: String,
       outputEncoder: Encoder[U],
       initialStateEncoder: Encoder[S]): Dataset[U] = {
-    transformWithState(statefulProcessor, timeMode,
+    transformWithState(statefulProcessor, eventTimeColumnName,
       outputMode, initialState)(outputEncoder, initialStateEncoder)
   }
 
+  /**
+   * Creates a new dataset with updated eventTimeColumn after the 
transformWithState
+   * logical node.
+   */
+  private def updateEventTimeColumnAfterTransformWithState[U: Encoder](
+      transformWithState: LogicalPlan,
+      eventTimeColumnName: String): Dataset[U] = {
+    val transformWithStateDataset = Dataset[U](
+      sparkSession,
+      transformWithState
+    )
+
+    Dataset[U](sparkSession, EliminateEventTimeWatermark(
+      UpdateEventTimeWatermarkColumn(
+        UnresolvedAttribute(eventTimeColumnName),
+        None,
+        transformWithStateDataset.logicalPlan)))
+  }
+
   /**
    * (Scala-specific)
    * Reduces the elements of each group of data using the specified binary 
function.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 41793c2df017..348cc00a1f97 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -442,6 +442,18 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       case EventTimeWatermark(columnName, delay, child) =>
         EventTimeWatermarkExec(columnName, delay, planLater(child)) :: Nil
 
+      case UpdateEventTimeWatermarkColumn(columnName, delay, child) =>
+        // we expect watermarkDelay to be resolved before physical planning.
+        if (delay.isEmpty) {
+          // This is a sanity check. We should not reach here as delay is 
updated during
+          // query plan resolution in 
[[ResolveUpdateEventTimeWatermarkColumn]] Analyzer rule.
+          throw SparkException.internalError(
+            "No watermark delay found in UpdateEventTimeWatermarkColumn 
logical node. " +
+              "You have hit a query analyzer bug. " +
+              "Please report your query to Spark user mailing list.")
+        }
+        UpdateEventTimeColumnExec(columnName, delay.get, None, 
planLater(child)) :: Nil
+
       case PhysicalAggregation(
         namedGroupingExpressions, aggregateExpressions, 
rewrittenResultExpressions, child) =>
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 7e094fee3254..54041abdc9ab 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
Predicate, SortOrder, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
+import 
org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark.updateEventTimeColumn
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToMillis
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.types.MetadataBuilder
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.AccumulatorV2
 
@@ -107,25 +110,72 @@ case class EventTimeWatermarkExec(
   }
 
   // Update the metadata on the eventTime column to include the desired delay.
-  override val output: Seq[Attribute] = child.output.map { a =>
-    if (a semanticEquals eventTime) {
-      val updatedMetadata = new MetadataBuilder()
-        .withMetadata(a.metadata)
-        .putLong(EventTimeWatermark.delayKey, delayMs)
-        .build()
-      a.withMetadata(updatedMetadata)
-    } else if (a.metadata.contains(EventTimeWatermark.delayKey)) {
-      // Remove existing watermark
-      val updatedMetadata = new MetadataBuilder()
-        .withMetadata(a.metadata)
-        .remove(EventTimeWatermark.delayKey)
-        .build()
-      a.withMetadata(updatedMetadata)
-    } else {
-      a
-    }
+  override val output: Seq[Attribute] = {
+    val delayMs = EventTimeWatermark.getDelayMs(delay)
+    updateEventTimeColumn(child.output, delayMs, eventTime)
   }
 
   override protected def withNewChildInternal(newChild: SparkPlan): 
EventTimeWatermarkExec =
     copy(child = newChild)
 }
+
+/**
+ * Updates the event time column to [[eventTime]] in the child output.
+ * Any watermark calculations performed after this node will use the
+ * updated eventTimeColumn.
+ *
+ * This node also ensures that output emitted by the child node adheres
+ * to the watermark. If the child node emits rows which are older than the global
+ * watermark, the node will throw a query execution error and fail the user
+ * query.
+ */
+case class UpdateEventTimeColumnExec(
+    eventTime: Attribute,
+    delay: CalendarInterval,
+    eventTimeWatermarkForLateEvents: Option[Long],
+    child: SparkPlan) extends UnaryExecNode {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitions[InternalRow] { dataIterator =>
+      val watermarkExpression = WatermarkSupport.watermarkExpression(
+        Some(eventTime), eventTimeWatermarkForLateEvents)
+
+      if (watermarkExpression.isEmpty) {
+        // watermark should always be defined in this node.
+        throw QueryExecutionErrors.cannotGetEventTimeWatermarkError()
+      }
+
+      val predicate = Predicate.create(watermarkExpression.get, child.output)
+      new Iterator[InternalRow] {
+        override def hasNext: Boolean = dataIterator.hasNext
+
+        override def next(): InternalRow = {
+          val row = dataIterator.next()
+          if (predicate.eval(row)) {
+            // child node emitted a row which is older than current watermark
+            // this is not allowed
+            val boundEventTimeExpression = 
bindReference[Expression](eventTime, child.output)
+            val eventTimeProjection = 
UnsafeProjection.create(boundEventTimeExpression)
+            val rowEventTime = eventTimeProjection(row)
+            throw QueryExecutionErrors.emittedRowsAreOlderThanWatermark(
+              eventTimeWatermarkForLateEvents.get, rowEventTime.getLong(0))
+          }
+          row
+        }
+      }
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  // Update the metadata on the eventTime column to include the desired delay.
+  override val output: Seq[Attribute] = {
+    val delayMs = EventTimeWatermark.getDelayMs(delay)
+    updateEventTimeColumn(child.output, delayMs, eventTime)
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): 
UpdateEventTimeColumnExec =
+    copy(child = newChild)
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 6b2e899eaf0c..f72e2eb407f8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, 
SparkPlan, SparkPlanner, UnaryExecNode}
+import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, 
SerializeFromObjectExec, SparkPlan, SparkPlanner, UnaryExecNode}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, 
UpdatingSessionsExec}
 import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
@@ -347,6 +347,28 @@ class IncrementalExecution(
           eventTimeWatermarkForEviction = 
inputWatermarkForEviction(m.stateInfo.get)
         )
 
+      // UpdateEventTimeColumnExec is used to tag the eventTime column, and 
validate
+      // emitted rows adhere to watermark in the output of transformWithState.
+      // Hence, this node shares the same watermark value as 
TransformWithStateExec.
+      // However, given that UpdateEventTimeColumnExec does not store any 
state, it
+      // does not have any StateInfo. We simply use the StateInfo of 
transformWithStateExec
+      // to propagate watermark to both UpdateEventTimeColumnExec and 
transformWithStateExec.
+      case UpdateEventTimeColumnExec(eventTime, delay, None,
+        SerializeFromObjectExec(serializer,
+        t: TransformWithStateExec)) if t.stateInfo.isDefined =>
+
+        val stateInfo = t.stateInfo.get
+        val iwLateEvents = inputWatermarkForLateEvents(stateInfo)
+        val iwEviction = inputWatermarkForEviction(stateInfo)
+
+        UpdateEventTimeColumnExec(eventTime, delay, iwLateEvents,
+          SerializeFromObjectExec(serializer,
+            t.copy(
+              eventTimeWatermarkForLateEvents = iwLateEvents,
+              eventTimeWatermarkForEviction = iwEviction)
+          ))
+
+
       case t: TransformWithStateExec if t.stateInfo.isDefined =>
         t.copy(
           eventTimeWatermarkForLateEvents = 
inputWatermarkForLateEvents(t.stateInfo.get),
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 840e065b5e78..fbd062acfb5d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -78,15 +78,33 @@ case class TransformWithStateExec(
   override def shortName: String = "transformWithStateExec"
 
   override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
+    if (timeMode == ProcessingTime) {
+      // TODO: check if we can return true only if actual timers are 
registered, or there is
+      // expired state
+      true
+    } else if (outputMode == OutputMode.Append || outputMode == 
OutputMode.Update) {
+      eventTimeWatermarkForEviction.isDefined &&
+      newInputWatermark > eventTimeWatermarkForEviction.get
+    } else {
+      false
+    }
+  }
+
+  /**
+   * Controls watermark propagation to downstream nodes. If timeMode is
+   * ProcessingTime, the output rows cannot be interpreted in eventTime, hence
+   * this node will not propagate watermark in this timeMode.
+   *
+   * For timeMode EventTime, the output watermark is the same as the input watermark
+   * because transformWithState does not allow users to set the event time column to be
+   * earlier than the watermark.
+   */
+  override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = {
     timeMode match {
       case ProcessingTime =>
-        // TODO: check if we can return true only if actual timers are 
registered, or there is
-        // expired state
-        true
-      case EventTime =>
-        eventTimeWatermarkForEviction.isDefined &&
-          newInputWatermark > eventTimeWatermarkForEviction.get
-      case _ => false
+        None
+      case _ =>
+        Some(inputWatermarkMs)
     }
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 3bf833816bcc..9add574e01fc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -359,7 +359,7 @@ object WatermarkSupport {
     if (optionalWatermarkExpression.isEmpty || optionalWatermarkMs.isEmpty) 
return None
 
     val watermarkAttribute = optionalWatermarkExpression.get
-    // If we are evicting based on a window, use the end of the window.  
Otherwise just
+    // If we are evicting based on a window, use the end of the window. 
Otherwise just
     // use the attribute itself.
     val evictionExpression =
       if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala
new file mode 100644
index 000000000000..5388d6f1fb68
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala
@@ -0,0 +1,411 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.sql.Timestamp
+import java.time.{Instant, LocalDateTime, ZoneId}
+
+import org.apache.spark.{SparkRuntimeException, SparkThrowable}
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamExecution}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.functions.window
+import org.apache.spark.sql.internal.SQLConf
+
+case class InputEventRow(
+    key: String,
+    eventTime: Timestamp,
+    event: String)
+
+case class OutputRow(
+    key: String,
+    outputEventTime: Timestamp,
+    count: Int)
+
+class TestStatefulProcessor
+  extends StatefulProcessor[String, InputEventRow, OutputRow] {
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {}
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[InputEventRow],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputRow] = {
+    if (inputRows.isEmpty) {
+      Iterator.empty
+    } else {
+      var minEventTime = inputRows.next().eventTime
+      var count = 1
+      inputRows.foreach { row =>
+        if (row.eventTime.before(minEventTime)) {
+          minEventTime = row.eventTime
+        }
+        count += 1
+      }
+      Iterator.single(OutputRow(key, minEventTime, count))
+    }
+  }
+}
+
+class InputCountStatefulProcessor[T]
+  extends StatefulProcessor[String, T, Int] {
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {}
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[T],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[Int] = {
+    Iterator.single(inputRows.size)
+  }
+}
+
+/**
+ * Emits output row with timestamp older than current watermark for batchId > 
0.
+ */
+class StatefulProcessorEmittingRowsOlderThanWatermark
+  extends StatefulProcessor[String, InputEventRow, OutputRow] {
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {}
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[InputEventRow],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputRow] = {
+    Iterator.single(
+      OutputRow(
+        key,
+        // always emit value with eventTime 1 which will fail after first 
batch, as
+        // watermark will move past 0L
+        Timestamp.from(Instant.ofEpochMilli(1)),
+        inputRows.size))
+  }
+}
+
+case class Window(
+    start: Timestamp,
+    end: Timestamp)
+
+case class AggEventRow(
+    window: Window,
+    count: Long)
+
+class TransformWithStateChainingSuite extends StreamTest {
+  import testImplicits._
+
+  test("watermark is propagated correctly for next stateful operator" +
+    " after transformWithState") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+
+      val result = inputData.toDS()
+        .withWatermark("eventTime", "1 minute")
+        .groupByKey(x => x.key)
+        .transformWithState[OutputRow](
+          new TestStatefulProcessor(),
+          "outputEventTime",
+          OutputMode.Append())
+        .groupBy(window($"outputEventTime", "1 minute"))
+        .count()
+        .as[AggEventRow]
+
+      testStream(result, OutputMode.Append())(
+        AddData(inputData, InputEventRow("k1", timestamp("2024-01-01 
00:00:00"), "e1")),
+        // watermark should be 1 minute behind `2024-01-01 00:00:00`, nothing 
is
+        // emitted as all records have timestamp > epoch
+        CheckNewAnswer(),
+        Execute("assertWatermarkEquals") { q =>
+          assertWatermarkEquals(q, timestamp("2023-12-31 23:59:00"))
+        },
+        AddData(inputData, InputEventRow("k1", timestamp("2024-02-01 
00:00:00"), "e1")),
+        // global watermark should now be 1 minute behind `2024-02-01 00:00:00`.
+        CheckNewAnswer(AggEventRow(
+          Window(timestamp("2024-01-01 00:00:00"), timestamp("2024-01-01 
00:01:00")), 1)
+        ),
+        Execute("assertWatermarkEquals") { q =>
+          assertWatermarkEquals(q, timestamp("2024-01-31 23:59:00"))
+        },
+        AddData(inputData, InputEventRow("k1", timestamp("2024-02-02 
00:00:00"), "e1")),
+        CheckNewAnswer(AggEventRow(
+          Window(timestamp("2024-02-01 00:00:00"), timestamp("2024-02-01 
00:01:00")), 1)
+        )
+      )
+    }
+  }
+
+  test("passing eventTime column to transformWithState fails if" +
+    " no watermark is defined") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+
+      val ex = intercept[AnalysisException] {
+        inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState[OutputRow](
+          new TestStatefulProcessor(),
+          "outputEventTime",
+          OutputMode.Append())
+      }
+
+      checkError(ex, "CANNOT_ASSIGN_EVENT_TIME_COLUMN_WITHOUT_WATERMARK")
+    }
+  }
+
+  test("missing eventTime column to transformWithState fails the query if" +
+    " another stateful operator is added") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+
+      val result = inputData.toDS()
+        .withWatermark("eventTime", "1 minute")
+        .groupByKey(x => x.key)
+        .transformWithState[OutputRow](
+          new TestStatefulProcessor(),
+          TimeMode.None(),
+          OutputMode.Append())
+        .groupBy(window($"outputEventTime", "1 minute"))
+        .count()
+
+      val ex = intercept[ExtendedAnalysisException] {
+        testStream(result, OutputMode.Append())(
+          StartStream()
+        )
+      }
+      assert(ex.getMessage.contains("there are streaming aggregations on" +
+        " streaming DataFrames/DataSets without watermark"))
+    }
+  }
+
+  test("chaining multiple transformWithState operators") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+
+      val result = inputData.toDS()
+        .withWatermark("eventTime", "1 minute")
+        .groupByKey(x => x.key)
+        .transformWithState[OutputRow](
+          new TestStatefulProcessor(),
+          "outputEventTime",
+          OutputMode.Append())
+        .groupByKey(x => x.key)
+        .transformWithState(
+          new InputCountStatefulProcessor[OutputRow](),
+          TimeMode.None(),
+          OutputMode.Append()
+        )
+
+      testStream(result, OutputMode.Append())(
+        AddData(inputData, InputEventRow("k1", timestamp("2024-01-01 
00:00:00"), "e1")),
+        CheckNewAnswer(1),
+        Execute("assertWatermarkEquals") { q =>
+          assertWatermarkEquals(q, timestamp("2023-12-31 23:59:00"))
+        },
+        AddData(inputData, InputEventRow("k1", timestamp("2024-02-01 
00:00:00"), "e1")),
+        CheckNewAnswer(1),
+        Execute("assertWatermarkEquals") { q =>
+          assertWatermarkEquals(q, timestamp("2024-01-31 23:59:00"))
+        },
+        AddData(inputData, InputEventRow("k1", timestamp("2024-02-02 
00:00:00"), "e1")),
+        CheckNewAnswer(1)
+      )
+    }
+  }
+
+  test("dropDuplicateWithWatermark after transformWithState operator" +
+    " fails if watermark column is not provided") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+      val result = inputData.toDS()
+        .withWatermark("eventTime", "1 minute")
+        .groupByKey(x => x.key)
+        .transformWithState[OutputRow](
+          new TestStatefulProcessor(),
+          TimeMode.None(),
+          OutputMode.Append())
+        .dropDuplicatesWithinWatermark()
+
+      val ex = intercept[ExtendedAnalysisException] {
+        testStream(result, OutputMode.Append())(
+          StartStream()
+        )
+      }
+      assert(ex.getMessage.contains("dropDuplicatesWithinWatermark is not 
supported on" +
+        " streaming DataFrames/DataSets without watermark"))
+    }
+  }
+
+  test("dropDuplicateWithWatermark after transformWithState operator") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+      val result = inputData.toDS()
+        .withWatermark("eventTime", "1 minute")
+        .groupByKey(x => x.key)
+        .transformWithState[OutputRow](
+          new TestStatefulProcessor(),
+          "outputEventTime",
+          OutputMode.Append())
+        .dropDuplicatesWithinWatermark()
+
+      testStream(result, OutputMode.Append())(
+        AddData(inputData, InputEventRow("k1", timestamp("2024-02-01 
00:00:00"), "e1"),
+          InputEventRow("k1", timestamp("2024-02-01 00:00:00"), "e1")),
+        CheckNewAnswer(OutputRow("k1", timestamp("2024-02-01 00:00:00"), 2)),
+        Execute("assertWatermarkEquals") { q =>
+          assertWatermarkEquals(q, timestamp("2024-01-31 23:59:00"))
+        }
+      )
+    }
+  }
+
+  test("query fails if the output dataset does not contain specified 
eventTimeColumn") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+
+      val ex = intercept[ExtendedAnalysisException] {
+        val result = inputData.toDS()
+          .withWatermark("eventTime", "1 minute")
+          .groupByKey(x => x.key)
+          .transformWithState[OutputRow](
+            new TestStatefulProcessor(),
+            "missingEventTimeColumn",
+            OutputMode.Append())
+
+        testStream(result, OutputMode.Append())(
+          StartStream()
+        )
+      }
+
+      checkError(ex, "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+        parameters = Map(
+          "objectName" -> "`missingEventTimeColumn`",
+          "proposal" -> "`outputEventTime`, `count`, `key`"))
+    }
+  }
+
+  test("query fails if the output dataset contains rows older than current 
watermark") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+      val result = inputData.toDS()
+        .withWatermark("eventTime", "1 minute")
+        .groupByKey(x => x.key)
+        .transformWithState[OutputRow](
+          new StatefulProcessorEmittingRowsOlderThanWatermark(),
+          "outputEventTime",
+          OutputMode.Append())
+
+      testStream(result, OutputMode.Append())(
+        AddData(inputData, InputEventRow("k1", timestamp("2024-02-01 
00:00:00"), "e1")),
+        // after first batch, the rows are emitted with timestamp 1 ms after 
epoch
+        CheckNewAnswer(OutputRow("k1", 
Timestamp.from(Instant.ofEpochMilli(1)), 1)),
+        // this batch would fail now, because watermark will move past 1ms 
after epoch
+        AddData(inputData, InputEventRow("k1", timestamp("2024-02-02 
00:00:00"), "e1")),
+        ExpectFailure[SparkRuntimeException] { ex =>
+          checkError(ex.asInstanceOf[SparkThrowable],
+            "EMITTING_ROWS_OLDER_THAN_WATERMARK_NOT_ALLOWED",
+            parameters = Map("currentWatermark" -> "1706774340000",
+              "emittedRowEventTime" -> "1000"))
+        }
+      )
+    }
+  }
+
+  test("ensure that watermark delay is resolved from a view") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+      inputData.toDS()
+        .withWatermark("eventTime", "1 minute")
+        .createTempView("tempViewWithWatermark")
+
+      val result = spark.readStream.table("tempViewWithWatermark")
+        .as[InputEventRow]
+        .groupByKey(x => x.key)
+        .transformWithState[OutputRow](
+          new TestStatefulProcessor(),
+          "outputEventTime",
+          OutputMode.Append())
+
+      testStream(result, OutputMode.Append())(
+        AddData(inputData, InputEventRow("k1", timestamp("2024-02-01 
00:00:00"), "e1")),
+        CheckNewAnswer(OutputRow("k1", timestamp("2024-02-01 00:00:00"), 1)),
+        Execute("assertWatermarkEquals") { q =>
+          assertWatermarkEquals(q, timestamp("2024-01-31 23:59:00"))
+        }
+      )
+    }
+  }
+
+  test("ensure that query fails if there is no watermark when reading from a 
view") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputData = MemoryStream[InputEventRow]
+      inputData.toDS()
+        .createTempView("tempViewWithoutWatermark")
+
+      val ex = intercept[AnalysisException] {
+        val result = spark.readStream.table("tempViewWithoutWatermark")
+          .as[InputEventRow]
+          .groupByKey(x => x.key)
+          .transformWithState[OutputRow](
+            new TestStatefulProcessor(),
+            "outputEventTime",
+            OutputMode.Append())
+
+        testStream(result, OutputMode.Append())(
+          AddData(inputData, InputEventRow("k1", timestamp("2024-02-01 
00:00:00"), "e1")),
+          ExpectFailure[SparkRuntimeException] { ex =>
+            checkError(ex.asInstanceOf[AnalysisException],
+              "CANNOT_ASSIGN_EVENT_TIME_COLUMN_WITHOUT_WATERMARK")
+          }
+        )
+      }
+
+      checkError(ex, "CANNOT_ASSIGN_EVENT_TIME_COLUMN_WITHOUT_WATERMARK")
+    }
+  }
+
+  private def timestamp(str: String): Timestamp = {
+    Timestamp.valueOf(str)
+  }
+
+  private def assertWatermarkEquals(q: StreamExecution, watermark: Timestamp): 
Unit = {
+    val queryWatermark = getQueryWatermark(q)
+    assert(queryWatermark.isDefined)
+    assert(queryWatermark.get === watermark)
+  }
+
+  private def getQueryWatermark(q: StreamExecution): Option[Timestamp] = {
+    import scala.jdk.CollectionConverters._
+    val eventTimeMap = q.lastProgress.eventTime.asScala
+    val queryWatermark = eventTimeMap.get("watermark")
+    queryWatermark.map { v =>
+      val instant = Instant.parse(v)
+      val local = LocalDateTime.ofInstant(instant, ZoneId.systemDefault())
+      Timestamp.valueOf(local)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to