Repository: spark
Updated Branches:
  refs/heads/master 2d73fcced -> c1e87e384


[SPARK-20030][SS] Event-time-based timeout for MapGroupsWithState

## What changes were proposed in this pull request?

Adds event-time-based timeout. The user sets the timeout timestamp directly 
using `KeyedState.setTimeoutTimestamp`. A key times out when the watermark 
crosses its timeout timestamp.

## How was this patch tested?
Unit tests

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #17361 from tdas/SPARK-20030.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c1e87e38
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c1e87e38
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c1e87e38

Branch: refs/heads/master
Commit: c1e87e384d1878308b42da80bb3d65be512aab55
Parents: 2d73fcc
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Tue Mar 21 21:27:08 2017 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Mar 21 21:27:08 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/streaming/KeyedStateTimeout.java  |  22 +-
 .../analysis/UnsupportedOperationChecker.scala  |  96 +++--
 .../sql/catalyst/plans/logical/object.scala     |   3 +-
 .../analysis/UnsupportedOperationsSuite.scala   |  16 +
 .../spark/sql/execution/SparkStrategies.scala   |   3 +-
 .../streaming/FlatMapGroupsWithStateExec.scala  |  87 ++--
 .../streaming/IncrementalExecution.scala        |   5 +-
 .../execution/streaming/KeyedStateImpl.scala    | 139 +++++--
 .../execution/streaming/statefulOperators.scala |  14 +-
 .../apache/spark/sql/streaming/KeyedState.scala |  97 ++++-
 .../streaming/FlatMapGroupsWithStateSuite.scala | 402 ++++++++++++-------
 11 files changed, 616 insertions(+), 268 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java
index cf112f2..e2e7ab1 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java
@@ -19,9 +19,7 @@ package org.apache.spark.sql.streaming;
 
 import org.apache.spark.annotation.Experimental;
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$;
-import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout;
-import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$;
+import org.apache.spark.sql.catalyst.plans.logical.*;
 
 /**
  * Represents the type of timeouts possible for the Dataset operations
@@ -34,9 +32,23 @@ import 
org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$;
 @InterfaceStability.Evolving
 public class KeyedStateTimeout {
 
-  /** Timeout based on processing time.  */
+  /**
+   * Timeout based on processing time. The duration of timeout can be set for 
each group in
+   * `map/flatMapGroupsWithState` by calling 
`KeyedState.setTimeoutDuration()`. See documentation
+   * on `KeyedState` for more details.
+   */
   public static KeyedStateTimeout ProcessingTimeTimeout() { return 
ProcessingTimeTimeout$.MODULE$; }
 
-  /** No timeout */
+  /**
+   * Timeout based on event-time. The event-time timestamp for timeout can be 
set for each
+   * group in `map/flatMapGroupsWithState` by calling 
`KeyedState.setTimeoutTimestamp()`.
+   * In addition, you have to define the watermark in the query using 
`Dataset.withWatermark`.
+   * When the watermark advances beyond the set timestamp of a group and the 
group has not
+   * received any data, then the group times out. See documentation on
+   * `KeyedState` for more details.
+   */
+  public static KeyedStateTimeout EventTimeTimeout() { return 
EventTimeTimeout$.MODULE$; }
+
+  /** No timeout. */
   public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index a9ff61e..7da7f55 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -147,49 +147,69 @@ object UnsupportedOperationChecker {
           throwError("Commands like CreateTable*, AlterTable*, Show* are not 
supported with " +
             "streaming DataFrames/Datasets")
 
-        // mapGroupsWithState: Allowed only when no aggregation + Update 
output mode
-        case m: FlatMapGroupsWithState if m.isStreaming && 
m.isMapGroupsWithState =>
-          if (collectStreamingAggregates(plan).isEmpty) {
-            if (outputMode != InternalOutputModes.Update) {
-              throwError("mapGroupsWithState is not supported with " +
-                s"$outputMode output mode on a streaming DataFrame/Dataset")
-            } else {
-              // Allowed when no aggregation + Update output mode
-            }
-          } else {
-            throwError("mapGroupsWithState is not supported with aggregation " 
+
-              "on a streaming DataFrame/Dataset")
-          }
-
-        // flatMapGroupsWithState without aggregation
-        case m: FlatMapGroupsWithState
-          if m.isStreaming && collectStreamingAggregates(plan).isEmpty =>
-          m.outputMode match {
-            case InternalOutputModes.Update =>
-              if (outputMode != InternalOutputModes.Update) {
-                throwError("flatMapGroupsWithState in update mode is not 
supported with " +
+        // mapGroupsWithState and flatMapGroupsWithState
+        case m: FlatMapGroupsWithState if m.isStreaming =>
+
+          // Check compatibility with output modes and aggregations in query
+          val aggsAfterFlatMapGroups = collectStreamingAggregates(plan)
+
+          if (m.isMapGroupsWithState) {                       // check 
mapGroupsWithState
+            // allowed only in update query output mode and without aggregation
+            if (aggsAfterFlatMapGroups.nonEmpty) {
+              throwError(
+                "mapGroupsWithState is not supported with aggregation " +
+                  "on a streaming DataFrame/Dataset")
+            } else if (outputMode != InternalOutputModes.Update) {
+              throwError(
+                "mapGroupsWithState is not supported with " +
                   s"$outputMode output mode on a streaming DataFrame/Dataset")
+            }
+          } else {                                           // check 
flatMapGroupsWithState
+            if (aggsAfterFlatMapGroups.isEmpty) {
+              // flatMapGroupsWithState without aggregation: operation's 
output mode must
+              // match query output mode
+              m.outputMode match {
+                case InternalOutputModes.Update if outputMode != 
InternalOutputModes.Update =>
+                  throwError(
+                    "flatMapGroupsWithState in update mode is not supported 
with " +
+                      s"$outputMode output mode on a streaming 
DataFrame/Dataset")
+
+                case InternalOutputModes.Append if outputMode != 
InternalOutputModes.Append =>
+                  throwError(
+                    "flatMapGroupsWithState in append mode is not supported 
with " +
+                      s"$outputMode output mode on a streaming 
DataFrame/Dataset")
+
+                case _ =>
               }
-            case InternalOutputModes.Append =>
-              if (outputMode != InternalOutputModes.Append) {
-                throwError("flatMapGroupsWithState in append mode is not 
supported with " +
-                  s"$outputMode output mode on a streaming DataFrame/Dataset")
+            } else {
+              // flatMapGroupsWithState with aggregation: update operation 
mode not allowed, and
+              // *groupsWithState after aggregation not allowed
+              if (m.outputMode == InternalOutputModes.Update) {
+                throwError(
+                  "flatMapGroupsWithState in update mode is not supported with 
" +
+                    "aggregation on a streaming DataFrame/Dataset")
+              } else if (collectStreamingAggregates(m).nonEmpty) {
+                throwError(
+                  "flatMapGroupsWithState in append mode is not supported 
after " +
+                    s"aggregation on a streaming DataFrame/Dataset")
               }
+            }
           }
 
-        // flatMapGroupsWithState(Update) with aggregation
-        case m: FlatMapGroupsWithState
-          if m.isStreaming && m.outputMode == InternalOutputModes.Update
-            && collectStreamingAggregates(plan).nonEmpty =>
-          throwError("flatMapGroupsWithState in update mode is not supported 
with " +
-            "aggregation on a streaming DataFrame/Dataset")
-
-        // flatMapGroupsWithState(Append) with aggregation
-        case m: FlatMapGroupsWithState
-          if m.isStreaming && m.outputMode == InternalOutputModes.Append
-            && collectStreamingAggregates(m).nonEmpty =>
-          throwError("flatMapGroupsWithState in append mode is not supported 
after " +
-            s"aggregation on a streaming DataFrame/Dataset")
+          // Check compatibility with timeout configs
+          if (m.timeout == EventTimeTimeout) {
+            // With event time timeout, watermark must be defined.
+            val watermarkAttributes = m.child.output.collect {
+              case a: Attribute if 
a.metadata.contains(EventTimeWatermark.delayKey) => a
+            }
+            if (watermarkAttributes.isEmpty) {
+              throwError(
+                "Watermark must be specified in the query using " +
+                  "'[Dataset/DataFrame].withWatermark()' for using event-time 
timeout in a " +
+                  "[map|flatMap]GroupsWithState. Event-time timeout not 
supported without " +
+                  "watermark.")(plan)
+            }
+          }
 
         case d: Deduplicate if collectStreamingAggregates(d).nonEmpty =>
           throwError("dropDuplicates is not supported after aggregation on a " 
+

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index d1f95fa..e0ecf8c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -353,9 +353,10 @@ case class MapGroups(
 /** Internal class representing State */
 trait LogicalKeyedState[S]
 
-/** Possible types of timeouts used in FlatMapGroupsWithState */
+/** Types of timeouts used in FlatMapGroupsWithState */
 case object NoTimeout extends KeyedStateTimeout
 case object ProcessingTimeTimeout extends KeyedStateTimeout
+case object EventTimeTimeout extends KeyedStateTimeout
 
 /** Factory for constructing new `MapGroupsWithState` nodes. */
 object FlatMapGroupsWithState {

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 08216e2..8f0a0c0 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -345,6 +345,22 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
     outputMode = Append,
     expectedMsgs = Seq("Mixing mapGroupsWithStates and 
flatMapGroupsWithStates"))
 
+  // mapGroupsWithState with event time timeout + watermark
+  assertNotSupportedInStreamingPlan(
+    "mapGroupsWithState - mapGroupsWithState with event time timeout without 
watermark",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, null, Update, 
isMapGroupsWithState = true,
+      EventTimeTimeout, streamRelation),
+    outputMode = Update,
+    expectedMsgs = Seq("watermark"))
+
+  assertSupportedInStreamingPlan(
+    "mapGroupsWithState - mapGroupsWithState with event time timeout with 
watermark",
+    FlatMapGroupsWithState(
+      null, att, att, Seq(att), Seq(att), att, null, Update, 
isMapGroupsWithState = true,
+      EventTimeTimeout, new TestStreamingRelation(attributeWithWatermark)),
+    outputMode = Update)
+
   // Deduplicate
   assertSupportedInStreamingPlan(
     "Deduplicate - Deduplicate on streaming relation before aggregation",

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 9e58e8c..ca2f6dd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -336,8 +336,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         timeout, child) =>
         val execPlan = FlatMapGroupsWithStateExec(
           func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, 
stateEnc, outputMode,
-          timeout, batchTimestampMs = 
KeyedStateImpl.NO_BATCH_PROCESSING_TIMESTAMP,
-          planLater(child))
+          timeout, batchTimestampMs = None, eventTimeWatermark = None, 
planLater(child))
         execPlan :: Nil
       case _ =>
         Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 991d8ef..52ad70c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.execution.streaming
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, 
AttributeReference, Expression, Literal, SortOrder, SpecificInternalRow, 
UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalKeyedState, 
ProcessingTimeTimeout}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, 
Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, 
AttributeReference, Expression, Literal, SortOrder, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, 
Distribution}
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.streaming.KeyedStateImpl.NO_TIMESTAMP
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode}
-import org.apache.spark.sql.types.{BooleanType, IntegerType}
+import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.util.CompletionIterator
 
 /**
@@ -39,7 +40,7 @@ import org.apache.spark.util.CompletionIterator
  * @param outputObjAttr used to define the output object
  * @param stateEncoder used to serialize/deserialize state before calling 
`func`
  * @param outputMode the output mode of `func`
- * @param timeout used to timeout groups that have not received data in a while
+ * @param timeoutConf used to timeout groups that have not received data in a 
while
  * @param batchTimestampMs processing timestamp of the current batch.
  */
 case class FlatMapGroupsWithStateExec(
@@ -52,11 +53,15 @@ case class FlatMapGroupsWithStateExec(
     stateId: Option[OperatorStateId],
     stateEncoder: ExpressionEncoder[Any],
     outputMode: OutputMode,
-    timeout: KeyedStateTimeout,
-    batchTimestampMs: Long,
-    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with 
StateStoreWriter {
+    timeoutConf: KeyedStateTimeout,
+    batchTimestampMs: Option[Long],
+    override val eventTimeWatermark: Option[Long],
+    child: SparkPlan
+  ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with 
WatermarkSupport {
 
-  private val isTimeoutEnabled = timeout == ProcessingTimeTimeout
+  import KeyedStateImpl._
+
+  private val isTimeoutEnabled = timeoutConf != NoTimeout
   private val timestampTimeoutAttribute =
     AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = 
false)()
   private val stateAttributes: Seq[Attribute] = {
@@ -64,8 +69,6 @@ case class FlatMapGroupsWithStateExec(
     if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else 
encSchemaAttribs
   }
 
-  import KeyedStateImpl._
-
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(groupingAttributes) :: Nil
@@ -74,9 +77,21 @@ case class FlatMapGroupsWithStateExec(
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
     Seq(groupingAttributes.map(SortOrder(_, Ascending)))
 
+  override def keyExpressions: Seq[Attribute] = groupingAttributes
+
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
+    // Throw errors early if parameters are not as expected
+    timeoutConf match {
+      case ProcessingTimeTimeout =>
+        require(batchTimestampMs.nonEmpty)
+      case EventTimeTimeout =>
+        require(eventTimeWatermark.nonEmpty)  // watermark value has been 
populated
+        require(watermarkExpression.nonEmpty) // input schema has watermark 
attribute
+      case _ =>
+    }
+
     child.execute().mapPartitionsWithStateStore[InternalRow](
       getStateId.checkpointLocation,
       getStateId.operatorId,
@@ -84,15 +99,23 @@ case class FlatMapGroupsWithStateExec(
       groupingAttributes.toStructType,
       stateAttributes.toStructType,
       sqlContext.sessionState,
-      Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iterator) 
=>
+      Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
         val updater = new StateStoreUpdater(store)
 
+        // If timeout is based on event time, then filter late data based on 
watermark
+        val filteredIter = watermarkPredicateForData match {
+          case Some(predicate) if timeoutConf == EventTimeTimeout =>
+            iter.filter(row => !predicate.eval(row))
+          case None =>
+            iter
+        }
+
         // Generate a iterator that returns the rows grouped by the grouping 
function
         // Note that this code ensures that the filtering for timeout occurs 
only after
         // all the data has been processed. This is to ensure that the timeout 
information of all
         // the keys with data is updated before they are processed for 
timeouts.
         val outputIterator =
-          updater.updateStateForKeysWithData(iterator) ++ 
updater.updateStateForTimedOutKeys()
+          updater.updateStateForKeysWithData(filteredIter) ++ 
updater.updateStateForTimedOutKeys()
 
         // Return an iterator of all the rows generated by all the keys, such 
that when fully
         // consumed, all the state updates will be committed by the state store
@@ -124,7 +147,7 @@ case class FlatMapGroupsWithStateExec(
     private val stateSerializer = {
       val encoderSerializer = stateEncoder.namedExpressions
       if (isTimeoutEnabled) {
-        encoderSerializer :+ Literal(KeyedStateImpl.TIMEOUT_TIMESTAMP_NOT_SET)
+        encoderSerializer :+ Literal(KeyedStateImpl.NO_TIMESTAMP)
       } else {
         encoderSerializer
       }
@@ -157,16 +180,19 @@ case class FlatMapGroupsWithStateExec(
     /** Find the groups that have timeout set and are timing out right now, 
and call the function */
     def updateStateForTimedOutKeys(): Iterator[InternalRow] = {
       if (isTimeoutEnabled) {
+        val timeoutThreshold = timeoutConf match {
+          case ProcessingTimeTimeout => batchTimestampMs.get
+          case EventTimeTimeout => eventTimeWatermark.get
+          case _ =>
+            throw new IllegalStateException(
+              s"Cannot filter timed out keys for $timeoutConf")
+        }
         val timingOutKeys = store.filter { case (_, stateRow) =>
           val timeoutTimestamp = getTimeoutTimestamp(stateRow)
-          timeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET && timeoutTimestamp < 
batchTimestampMs
+          timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < 
timeoutThreshold
         }
         timingOutKeys.flatMap { case (keyRow, stateRow) =>
-          callFunctionAndUpdateState(
-            keyRow,
-            Iterator.empty,
-            Some(stateRow),
-            hasTimedOut = true)
+          callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), 
hasTimedOut = true)
         }
       } else Iterator.empty
     }
@@ -186,7 +212,11 @@ case class FlatMapGroupsWithStateExec(
       val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value 
rows to objects
       val stateObjOption = getStateObj(prevStateRowOption)
       val keyedState = new KeyedStateImpl(
-        stateObjOption, batchTimestampMs, isTimeoutEnabled, hasTimedOut)
+        stateObjOption,
+        batchTimestampMs.getOrElse(NO_TIMESTAMP),
+        eventTimeWatermark.getOrElse(NO_TIMESTAMP),
+        timeoutConf,
+        hasTimedOut)
 
       // Call function, get the returned objects and convert them to rows
       val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj =>
@@ -196,8 +226,6 @@ case class FlatMapGroupsWithStateExec(
 
       // When the iterator is consumed, then write changes to state
       def onIteratorCompletion: Unit = {
-        // Has the timeout information changed
-
         if (keyedState.hasRemoved) {
           store.remove(keyRow)
           numUpdatedStateRows += 1
@@ -205,26 +233,25 @@ case class FlatMapGroupsWithStateExec(
         } else {
           val previousTimeoutTimestamp = prevStateRowOption match {
             case Some(row) => getTimeoutTimestamp(row)
-            case None => TIMEOUT_TIMESTAMP_NOT_SET
+            case None => NO_TIMESTAMP
           }
-
+          val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
           val stateRowToWrite = if (keyedState.hasUpdated) {
             getStateRow(keyedState.get)
           } else {
             prevStateRowOption.orNull
           }
 
-          val hasTimeoutChanged = keyedState.getTimeoutTimestamp != 
previousTimeoutTimestamp
+          val hasTimeoutChanged = currentTimeoutTimestamp != 
previousTimeoutTimestamp
           val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged
 
           if (shouldWriteState) {
             if (stateRowToWrite == null) {
               // This should never happen because checks in KeyedStateImpl 
should avoid cases
               // where empty state would need to be written
-              throw new IllegalStateException(
-                "Attempting to write empty state")
+              throw new IllegalStateException("Attempting to write empty 
state")
             }
-            setTimeoutTimestamp(stateRowToWrite, 
keyedState.getTimeoutTimestamp)
+            setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
             store.put(keyRow.copy(), stateRowToWrite.copy())
             numUpdatedStateRows += 1
           }
@@ -247,7 +274,7 @@ case class FlatMapGroupsWithStateExec(
 
     /** Returns the timeout timestamp of a state row is set */
     def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
-      if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else 
TIMEOUT_TIMESTAMP_NOT_SET
+      if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else 
NO_TIMESTAMP
     }
 
     /** Set the timestamp in a state row */

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index a934c75..0f0e4a9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -108,7 +108,10 @@ class IncrementalExecution(
       case m: FlatMapGroupsWithStateExec =>
         val stateId =
           OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), 
currentBatchId)
-        m.copy(stateId = Some(stateId), batchTimestampMs = 
offsetSeqMetadata.batchTimestampMs)
+        m.copy(
+          stateId = Some(stateId),
+          batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
+          eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
index ac421d3..edfd35b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
@@ -17,37 +17,45 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.sql.Date
+
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.sql.streaming.KeyedState
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, 
ProcessingTimeTimeout}
+import org.apache.spark.sql.execution.streaming.KeyedStateImpl._
+import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout}
 import org.apache.spark.unsafe.types.CalendarInterval
 
+
 /**
  * Internal implementation of the [[KeyedState]] interface. Methods are not 
thread-safe.
  * @param optionalValue Optional value of the state
  * @param batchProcessingTimeMs Processing time of current batch, used to 
calculate timestamp
  *                              for processing time timeouts
- * @param isTimeoutEnabled Whether timeout is enabled. This will be used to 
check whether the user
- *                         is allowed to configure timeouts.
+ * @param timeoutConf     Type of timeout configured. Based on this, different 
operations will
+ *                        be supported.
  * @param hasTimedOut     Whether the key for which this state wrapped is 
being created is
  *                        getting timed out or not.
  */
 private[sql] class KeyedStateImpl[S](
     optionalValue: Option[S],
     batchProcessingTimeMs: Long,
-    isTimeoutEnabled: Boolean,
+    eventTimeWatermarkMs: Long,
+    timeoutConf: KeyedStateTimeout,
     override val hasTimedOut: Boolean) extends KeyedState[S] {
 
-  import KeyedStateImpl._
-
   // Constructor to create dummy state when using mapGroupsWithState in a 
batch query
   def this(optionalValue: Option[S]) = this(
-    optionalValue, -1, isTimeoutEnabled = false, hasTimedOut = false)
+    optionalValue,
+    batchProcessingTimeMs = NO_TIMESTAMP,
+    eventTimeWatermarkMs = NO_TIMESTAMP,
+    timeoutConf = KeyedStateTimeout.NoTimeout,
+    hasTimedOut = false)
   private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
   private var defined: Boolean = optionalValue.isDefined
   private var updated: Boolean = false // whether value has been updated (but 
not removed)
   private var removed: Boolean = false // whether value has been removed
-  private var timeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET
+  private var timeoutTimestamp: Long = NO_TIMESTAMP
 
   // ========= Public API =========
   override def exists: Boolean = defined
@@ -82,13 +90,14 @@ private[sql] class KeyedStateImpl[S](
     defined = false
     updated = false
     removed = true
-    timeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET
+    timeoutTimestamp = NO_TIMESTAMP
   }
 
   override def setTimeoutDuration(durationMs: Long): Unit = {
-    if (!isTimeoutEnabled) {
+    if (timeoutConf != ProcessingTimeTimeout) {
       throw new UnsupportedOperationException(
-        "Cannot set timeout information without enabling timeout in 
map/flatMapGroupsWithState")
+        "Cannot set timeout duration without enabling processing time timeout 
in " +
+          "map/flatMapGroupsWithState")
     }
     if (!defined) {
       throw new IllegalStateException(
@@ -99,7 +108,7 @@ private[sql] class KeyedStateImpl[S](
     if (durationMs <= 0) {
       throw new IllegalArgumentException("Timeout duration must be positive")
     }
-    if (!removed && batchProcessingTimeMs != NO_BATCH_PROCESSING_TIMESTAMP) {
+    if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) {
       timeoutTimestamp = durationMs + batchProcessingTimeMs
     } else {
       // This is being called in a batch query, hence no processing timestamp.
@@ -108,29 +117,55 @@ private[sql] class KeyedStateImpl[S](
   }
 
   override def setTimeoutDuration(duration: String): Unit = {
-    if (StringUtils.isBlank(duration)) {
-      throw new IllegalArgumentException(
-        "The window duration, slide duration and start time cannot be null or 
blank.")
-    }
-    val intervalString = if (duration.startsWith("interval")) {
-      duration
-    } else {
-      "interval " + duration
+    setTimeoutDuration(parseDuration(duration))
+  }
+
+  @throws[IllegalArgumentException]("if 'timestampMs' is not positive")
+  @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
+  @throws[UnsupportedOperationException](
+    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a 
streaming query")
+  override def setTimeoutTimestamp(timestampMs: Long): Unit = {
+    checkTimeoutTimestampAllowed()
+    if (timestampMs <= 0) {
+      throw new IllegalArgumentException("Timeout timestamp must be positive")
     }
-    val cal = CalendarInterval.fromString(intervalString)
-    if (cal == null) {
+    if (eventTimeWatermarkMs != NO_TIMESTAMP && timestampMs < 
eventTimeWatermarkMs) {
       throw new IllegalArgumentException(
-        s"The provided duration ($duration) is not valid.")
+        s"Timeout timestamp ($timestampMs) cannot be earlier than the " +
+          s"current watermark ($eventTimeWatermarkMs)")
     }
-    if (cal.milliseconds < 0 || cal.months < 0) {
-      throw new IllegalArgumentException("Timeout duration must be positive")
+    if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) {
+      timeoutTimestamp = timestampMs
+    } else {
+      // This is being called in a batch query, hence no processing timestamp.
+      // Just ignore any attempts to set timeout.
     }
+  }
 
-    val delayMs = {
-      val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31
-      cal.milliseconds + cal.months * millisPerMonth
-    }
-    setTimeoutDuration(delayMs)
+  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
+  @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
+  @throws[UnsupportedOperationException](
+    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a 
streaming query")
+  override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: 
String): Unit = {
+    checkTimeoutTimestampAllowed()
+    setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs)
+  }
+
+  @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
+  @throws[UnsupportedOperationException](
+    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a 
streaming query")
+  override def setTimeoutTimestamp(timestamp: Date): Unit = {
+    checkTimeoutTimestampAllowed()
+    setTimeoutTimestamp(timestamp.getTime)
+  }
+
+  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
+  @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
+  @throws[UnsupportedOperationException](
+    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a 
streaming query")
+  override def setTimeoutTimestamp(timestamp: Date, additionalDuration: 
String): Unit = {
+    checkTimeoutTimestampAllowed()
+    setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration))
   }
 
   override def toString: String = {
@@ -147,14 +182,46 @@ private[sql] class KeyedStateImpl[S](
 
   /** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */
   def getTimeoutTimestamp: Long = timeoutTimestamp
+
+  private def parseDuration(duration: String): Long = {
+    if (StringUtils.isBlank(duration)) {
+      throw new IllegalArgumentException(
+        "Provided duration is null or blank.")
+    }
+    val intervalString = if (duration.startsWith("interval")) {
+      duration
+    } else {
+      "interval " + duration
+    }
+    val cal = CalendarInterval.fromString(intervalString)
+    if (cal == null) {
+      throw new IllegalArgumentException(
+        s"Provided duration ($duration) is not valid.")
+    }
+    if (cal.milliseconds < 0 || cal.months < 0) {
+      throw new IllegalArgumentException(s"Provided duration ($duration) is 
not positive")
+    }
+
+    val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31
+    cal.milliseconds + cal.months * millisPerMonth
+  }
+
+  private def checkTimeoutTimestampAllowed(): Unit = {
+    if (timeoutConf != EventTimeTimeout) {
+      throw new UnsupportedOperationException(
+        "Cannot set timeout timestamp without enabling event time timeout in " 
+
+          "map/flatMapGroupsWithState")
+    }
+    if (!defined) {
+      throw new IllegalStateException(
+        "Cannot set timeout timestamp without any state value, " +
+          "state has either not been initialized, or has already been removed")
+    }
+  }
 }
 
 
 private[sql] object KeyedStateImpl {
-  // Value used in the state row to represent the lack of any timeout timestamp
-  val TIMEOUT_TIMESTAMP_NOT_SET = -1L
-
-  // Value to represent that no batch processing timestamp is passed to 
KeyedStateImpl. This is
-  // used in batch queries where there are no streaming batches and timeouts.
-  val NO_BATCH_PROCESSING_TIMESTAMP = -1L
+  // Value used to represent the lack of a valid timestamp as a long
+  val NO_TIMESTAMP = -1L
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 6d2de44..f72144a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -80,7 +80,7 @@ trait WatermarkSupport extends UnaryExecNode {
   /** Generate an expression that matches data older than the watermark */
   lazy val watermarkExpression: Option[Expression] = {
     val optionalWatermarkAttribute =
-      keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
+      child.output.find(_.metadata.contains(EventTimeWatermark.delayKey))
 
     optionalWatermarkAttribute.map { watermarkAttribute =>
       // If we are evicting based on a window, use the end of the window.  
Otherwise just
@@ -101,14 +101,12 @@ trait WatermarkSupport extends UnaryExecNode {
     }
   }
 
-  /** Generate a predicate based on keys that matches data older than the 
watermark */
+  /** Predicate based on keys that matches data older than the watermark */
   lazy val watermarkPredicateForKeys: Option[Predicate] =
     watermarkExpression.map(newPredicate(_, keyExpressions))
 
-  /**
-   * Generate a predicate based on the child output that matches data older 
than the watermark.
-   */
-  lazy val watermarkPredicate: Option[Predicate] =
+  /** Predicate based on the child output that matches data older than the 
watermark. */
+  lazy val watermarkPredicateForData: Option[Predicate] =
     watermarkExpression.map(newPredicate(_, child.output))
 }
 
@@ -218,7 +216,7 @@ case class StateStoreSaveExec(
             new Iterator[InternalRow] {
 
               // Filter late date using watermark if specified
-              private[this] val baseIterator = watermarkPredicate match {
+              private[this] val baseIterator = watermarkPredicateForData match 
{
                 case Some(predicate) => iter.filter((row: InternalRow) => 
!predicate.eval(row))
                 case None => iter
               }
@@ -285,7 +283,7 @@ case class StreamingDeduplicateExec(
       val numTotalStateRows = longMetric("numTotalStateRows")
       val numUpdatedStateRows = longMetric("numUpdatedStateRows")
 
-      val baseIterator = watermarkPredicate match {
+      val baseIterator = watermarkPredicateForData match {
         case Some(predicate) => iter.filter(row => !predicate.eval(row))
         case None => iter
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala
index 6b4b1ce..461de04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala
@@ -55,7 +55,7 @@ import 
org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
  *    batch, nor with streaming Datasets.
  *  - All the data will be shuffled before applying the function.
  *  - If timeout is set, then the function will also be called with no values.
- *    See more details on KeyedStateTimeout` below.
+ *    See more details on `KeyedStateTimeout` below.
  *
  * Important points to note about using `KeyedState`.
  *  - The value of the state cannot be null. So updating state with null will 
throw
@@ -68,20 +68,38 @@ import 
org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
  *
  * Important points to note about using `KeyedStateTimeout`.
  *  - The timeout type is a global param across all the keys (set as `timeout` 
param in
- *    `[map|flatMap]GroupsWithState`, but the exact timeout duration is 
configurable per key
- *    (by calling `setTimeout...()` in `KeyedState`).
- *  - When the timeout occurs for a key, the function is called with no 
values, and
+ *    `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp 
is configurable per
+ *    key by calling `setTimeout...()` in `KeyedState`.
+ *  - Timeouts can be either based on processing time (i.e.
+ *    [[KeyedStateTimeout.ProcessingTimeTimeout]]) or event time (i.e.
+ *    [[KeyedStateTimeout.EventTimeTimeout]]).
+ *  - With `ProcessingTimeTimeout`, the timeout duration can be set by calling
+ *    `KeyedState.setTimeoutDuration`. The timeout will occur when the clock 
has advanced by the set
+ *    duration. Guarantees provided by this timeout with a duration of D ms 
are as follows:
+ *    - Timeout will never occur before the clock time has advanced by D ms
+ *    - Timeout will occur eventually when there is a trigger in the query
+ *      (i.e. after D ms). So there is no strict upper bound on when the 
timeout would occur.
+ *      For example, the trigger interval of the query will affect when the 
timeout actually occurs.
+ *      If there is no data in the stream (for any key) for a while, then 
there will not be
+ *      any trigger and timeout function call will not occur until there is 
data.
+ *    - Since the processing time timeout is based on the clock time, it is 
affected by the
+ *      variations in the system clock (i.e. time zone changes, clock skew, 
etc.).
+ *  - With `EventTimeTimeout`, the user also has to specify the event 
time watermark in
+ *    the query using `Dataset.withWatermark()`. With this setting, data that 
is older than the
+ *    watermark are filtered out. The timeout can be enabled for a key by 
setting a timestamp using
+ *    `KeyedState.setTimeoutTimestamp()`, and the timeout would occur when the 
watermark advances
+ *    beyond the set timestamp. You can control the timeout delay by two 
parameters - (i) the watermark
+ *    delay and (ii) an additional duration beyond the timestamp in the event 
(which is guaranteed to be
+ *    newer than the watermark due to the filtering). Guarantees provided by this timeout 
are as follows:
+ *    - Timeout will never occur before the watermark has exceeded the set 
timeout.
+ *    - Similar to processing time timeouts, there is no strict upper bound 
on the delay when
+ *      the timeout actually occurs. The watermark can advance only when there 
is data in the
+ *      stream, and the event time of the data has actually advanced.
+ *  - When the timeout occurs for a key, the function is called for that key 
with no values, and
  *    `KeyedState.hasTimedOut()` set to true.
  *  - The timeout is reset for key every time the function is called on the 
key, that is,
  *    when the key has new data, or the key has timed out. So the user has to 
set the timeout
  *    duration every time the function is called, otherwise there will not be 
any timeout set.
- *  - Guarantees provided on processing-time-based timeout of key, when 
timeout duration is D ms:
- *    - Timeout will never be called before real clock time has advanced by D 
ms
- *    - Timeout will be called eventually when there is a trigger in the query
- *      (i.e. after D ms). So there is a no strict upper bound on when the 
timeout would occur.
- *      For example, the trigger interval of the query will affect when the 
timeout is actually hit.
- *      If there is no data in the stream (for any key) for a while, then 
their will not be
- *      any trigger and timeout will not be hit until there is data.
  *
  * Scala example of using KeyedState in `mapGroupsWithState`:
  * {{{
@@ -194,7 +212,8 @@ trait KeyedState[S] extends LogicalKeyedState[S] {
 
   /**
    * Set the timeout duration in ms for this key.
-   * @note Timeouts must be enabled in `[map/flatmap]GroupsWithStates`.
+   *
+   * @note ProcessingTimeTimeout must be enabled in 
`[map|flatMap]GroupsWithState`.
    */
   @throws[IllegalArgumentException]("if 'durationMs' is not positive")
   @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
@@ -204,11 +223,63 @@ trait KeyedState[S] extends LogicalKeyedState[S] {
 
   /**
    * Set the timeout duration for this key as a string. For example, "1 hour", 
"2 days", etc.
-   * @note, Timeouts must be enabled in `[map/flatmap]GroupsWithStates`.
+   *
+   * @note ProcessingTimeTimeout must be enabled in 
`[map|flatMap]GroupsWithState`.
    */
   @throws[IllegalArgumentException]("if 'duration' is not a valid duration")
   @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
   @throws[UnsupportedOperationException](
     "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a 
streaming query")
   def setTimeoutDuration(duration: String): Unit
+
+  @throws[IllegalArgumentException]("if 'timestampMs' is not positive")
+  @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
+  @throws[UnsupportedOperationException](
+    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a 
streaming query")
+  /**
+   * Set the timeout timestamp for this key as milliseconds in epoch time.
+   * This timestamp cannot be older than the current watermark.
+   *
+   * @note EventTimeTimeout must be enabled in 
`[map|flatMap]GroupsWithState`.
+   */
+  def setTimeoutTimestamp(timestampMs: Long): Unit
+
+  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
+  @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
+  @throws[UnsupportedOperationException](
+    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a 
streaming query")
+  /**
+   * Set the timeout timestamp for this key as milliseconds in epoch time and 
an additional
+   * duration as a string (e.g. "1 hour", "2 days", etc.).
+   * The final timestamp (including the additional duration) cannot be older 
than the
+   * current watermark.
+   *
+   * @note EventTimeTimeout must be enabled in 
`[map|flatMap]GroupsWithState`.
+   */
+  def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit
+
+  @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
+  @throws[UnsupportedOperationException](
+    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a 
streaming query")
+  /**
+   * Set the timeout timestamp for this key as a java.sql.Date.
+   * This timestamp cannot be older than the current watermark.
+   *
+   * @note EventTimeTimeout must be enabled in 
`[map|flatMap]GroupsWithState`.
+   */
+  def setTimeoutTimestamp(timestamp: java.sql.Date): Unit
+
+  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
+  @throws[IllegalStateException]("when state is either not initialized, or 
already removed")
+  @throws[UnsupportedOperationException](
+    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a 
streaming query")
+  /**
+   * Set the timeout timestamp for this key as a java.sql.Date and an 
additional
+   * duration as a string (e.g. "1 hour", "2 days", etc.).
+   * The final timestamp (including the additional duration) cannot be older 
than the
+   * current watermark.
+   *
+   * @note EventTimeTimeout must be enabled in 
`[map|flatMap]GroupsWithState`.
+   */
+  def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: 
String): Unit
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c1e87e38/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 7daa5e6..fe72283 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.streaming
 
-import java.util
+import java.sql.Date
 import java.util.concurrent.ConcurrentHashMap
 
 import org.scalatest.BeforeAndAfterAll
@@ -44,6 +44,7 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest with BeforeAndAf
 
   import testImplicits._
   import KeyedStateImpl._
+  import KeyedStateTimeout._
 
   override def afterAll(): Unit = {
     super.afterAll()
@@ -96,77 +97,93 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest with BeforeAndAf
     }
   }
 
-  test("KeyedState - setTimeoutDuration, hasTimedOut") {
-    import KeyedStateImpl._
-    var state: KeyedStateImpl[Int] = null
-
-    // When isTimeoutEnabled = false, then setTimeoutDuration() is not allowed
+  test("KeyedState - setTimeout**** with NoTimeout") {
     for (initState <- Seq(None, Some(5))) {
       // for different initial state
-      state = new KeyedStateImpl(initState, 1000, isTimeoutEnabled = false, 
hasTimedOut = false)
-      assert(state.hasTimedOut === false)
-      assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
-      intercept[UnsupportedOperationException] {
-        state.setTimeoutDuration(1000)
-      }
-      intercept[UnsupportedOperationException] {
-        state.setTimeoutDuration("1 day")
-      }
-      assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+      implicit val state = new KeyedStateImpl(initState, 1000, 1000, 
NoTimeout, hasTimedOut = false)
+      testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+      testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
     }
+  }
 
-    def testTimeoutNotAllowed(): Unit = {
-      intercept[IllegalStateException] {
-        state.setTimeoutDuration(1000)
-      }
-      assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
-      intercept[IllegalStateException] {
-        state.setTimeoutDuration("2 second")
-      }
-      assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
-    }
+  test("KeyedState - setTimeout**** with ProcessingTimeTimeout") {
+    implicit var state: KeyedStateImpl[Int] = null
 
-    // When isTimeoutEnabled = true, then setTimeoutDuration() is not allowed 
until the
-    // state is be defined
-    state = new KeyedStateImpl(None, 1000, isTimeoutEnabled = true, 
hasTimedOut = false)
-    assert(state.hasTimedOut === false)
-    assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
-    testTimeoutNotAllowed()
+    state = new KeyedStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, 
hasTimedOut = false)
+    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    testTimeoutDurationNotAllowed[IllegalStateException](state)
+    testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
 
-    // After state has been set, setTimeoutDuration() is allowed, and
-    // getTimeoutTimestamp returned correct timestamp
     state.update(5)
-    assert(state.hasTimedOut === false)
-    assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
     state.setTimeoutDuration(1000)
     assert(state.getTimeoutTimestamp === 2000)
     state.setTimeoutDuration("2 second")
     assert(state.getTimeoutTimestamp === 3000)
-    assert(state.hasTimedOut === false)
+    testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+    state.remove()
+    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    testTimeoutDurationNotAllowed[IllegalStateException](state)
+    testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+  }
+
+  test("KeyedState - setTimeout**** with EventTimeTimeout") {
+    implicit val state = new KeyedStateImpl[Int](
+      None, 1000, 1000, EventTimeTimeout, hasTimedOut = false)
+    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+    testTimeoutTimestampNotAllowed[IllegalStateException](state)
+
+    state.update(5)
+    state.setTimeoutTimestamp(10000)
+    assert(state.getTimeoutTimestamp === 10000)
+    state.setTimeoutTimestamp(new Date(20000))
+    assert(state.getTimeoutTimestamp === 20000)
+    testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+
+    state.remove()
+    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+    testTimeoutTimestampNotAllowed[IllegalStateException](state)
+  }
+
+  test("KeyedState - illegal params to setTimeout****") {
+    var state: KeyedStateImpl[Int] = null
 
-    // setTimeoutDuration() with negative values or 0 is not allowed
+    // Test setTimeout****() with illegal values
     def testIllegalTimeout(body: => Unit): Unit = {
       intercept[IllegalArgumentException] { body }
-      assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+      assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
     }
-    state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, 
hasTimedOut = false)
+
+    state = new KeyedStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, 
hasTimedOut = false)
     testIllegalTimeout { state.setTimeoutDuration(-1000) }
     testIllegalTimeout { state.setTimeoutDuration(0) }
     testIllegalTimeout { state.setTimeoutDuration("-2 second") }
     testIllegalTimeout { state.setTimeoutDuration("-1 month") }
     testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") }
 
-    // Test remove() clear timeout timestamp, and setTimeoutDuration() is not 
allowed after that
-    state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, 
hasTimedOut = false)
-    state.remove()
-    assert(state.hasTimedOut === false)
-    assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
-    testTimeoutNotAllowed()
-
-    // Test hasTimedOut = true
-    state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, 
hasTimedOut = true)
-    assert(state.hasTimedOut === true)
-    assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+    state = new KeyedStateImpl(Some(5), 1000, 1000, EventTimeTimeout, 
hasTimedOut = false)
+    testIllegalTimeout { state.setTimeoutTimestamp(-10000) }
+    testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") }
+    testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") }
+    testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") }
+    testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) }
+    testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 
second") }
+    testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 
month") }
+    testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month 
-1 day") }
+  }
+
+  test("KeyedState - hasTimedOut") {
+    for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, 
EventTimeTimeout)) {
+      for (initState <- Seq(None, Some(5))) {
+        val state1 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, 
hasTimedOut = false)
+        assert(state1.hasTimedOut === false)
+        val state2 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, 
hasTimedOut = true)
+        assert(state2.hasTimedOut === true)
+      }
+    }
   }
 
   test("KeyedState - primitive type") {
@@ -187,133 +204,186 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest with BeforeAndAf
   }
 
   // Values used for testing StateStoreUpdater
-  val currentTimestamp = 1000
-  val beforeCurrentTimestamp = 999
-  val afterCurrentTimestamp = 1001
+  val currentBatchTimestamp = 1000
+  val currentBatchWatermark = 1000
+  val beforeTimeoutThreshold = 999
+  val afterTimeoutThreshold = 1001
+
 
-  // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is 
disabled
+  // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = 
NoTimeout
   for (priorState <- Seq(None, Some(0))) {
     val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no 
prior state"
-    val testName = s"timeout disabled - $priorStateStr - "
+    val testName = s"NoTimeout - $priorStateStr - "
 
     testStateUpdateWithData(
       testName + "no update",
       stateUpdates = state => { /* do nothing */ },
-      timeoutType = KeyedStateTimeout.NoTimeout,
+      timeoutConf = KeyedStateTimeout.NoTimeout,
       priorState = priorState,
       expectedState = priorState)    // should not change
 
     testStateUpdateWithData(
       testName + "state updated",
       stateUpdates = state => { state.update(5) },
-      timeoutType = KeyedStateTimeout.NoTimeout,
+      timeoutConf = KeyedStateTimeout.NoTimeout,
       priorState = priorState,
       expectedState = Some(5))     // should change
 
     testStateUpdateWithData(
       testName + "state removed",
       stateUpdates = state => { state.remove() },
-      timeoutType = KeyedStateTimeout.NoTimeout,
+      timeoutConf = KeyedStateTimeout.NoTimeout,
       priorState = priorState,
       expectedState = None)        // should be removed
   }
 
-  // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is 
enabled
+  // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != 
NoTimeout
   for (priorState <- Seq(None, Some(0))) {
-    for (priorTimeoutTimestamp <- Seq(TIMEOUT_TIMESTAMP_NOT_SET, 1000)) {
-      var testName = s"timeout enabled - "
+    for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) {
+      var testName = s""
       if (priorState.nonEmpty) {
         testName += "prior state set, "
         if (priorTimeoutTimestamp == 1000) {
-          testName += "prior timeout set - "
+          testName += "prior timeout set"
         } else {
-          testName += "no prior timeout - "
+          testName += "no prior timeout"
         }
       } else {
-        testName += "no prior state - "
+        testName += "no prior state"
+      }
+      for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) {
+
+        testStateUpdateWithData(
+          s"$timeoutConf - $testName - no update",
+          stateUpdates = state => { /* do nothing */ },
+          timeoutConf = timeoutConf,
+          priorState = priorState,
+          priorTimeoutTimestamp = priorTimeoutTimestamp,
+          expectedState = priorState,                           // state 
should not change
+          expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset
+
+        testStateUpdateWithData(
+          s"$timeoutConf - $testName - state updated",
+          stateUpdates = state => { state.update(5) },
+          timeoutConf = timeoutConf,
+          priorState = priorState,
+          priorTimeoutTimestamp = priorTimeoutTimestamp,
+          expectedState = Some(5),                              // state 
should change
+          expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset
+
+        testStateUpdateWithData(
+          s"$timeoutConf - $testName - state removed",
+          stateUpdates = state => { state.remove() },
+          timeoutConf = timeoutConf,
+          priorState = priorState,
+          priorTimeoutTimestamp = priorTimeoutTimestamp,
+          expectedState = None)                                 // state 
should be removed
       }
 
       testStateUpdateWithData(
-        testName + "no update",
-        stateUpdates = state => { /* do nothing */ },
-        timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
-        priorState = priorState,
-        priorTimeoutTimestamp = priorTimeoutTimestamp,
-        expectedState = priorState,                           // state should 
not change
-        expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp 
should be reset
-
-      testStateUpdateWithData(
-        testName + "state updated",
-        stateUpdates = state => { state.update(5) },
-        timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
+        s"ProcessingTimeTimeout - $testName - state and timeout duration 
updated",
+        stateUpdates =
+          (state: KeyedState[Int]) => { state.update(5); 
state.setTimeoutDuration(5000) },
+        timeoutConf = ProcessingTimeTimeout,
         priorState = priorState,
         priorTimeoutTimestamp = priorTimeoutTimestamp,
-        expectedState = Some(5),                              // state should 
change
-        expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp 
should be reset
+        expectedState = Some(5),                                 // state 
should change
+        expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp 
should change
 
       testStateUpdateWithData(
-        testName + "state removed",
-        stateUpdates = state => { state.remove() },
-        timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
+        s"EventTimeTimeout - $testName - state and timeout timestamp updated",
+        stateUpdates =
+          (state: KeyedState[Int]) => { state.update(5); 
state.setTimeoutTimestamp(5000) },
+        timeoutConf = EventTimeTimeout,
         priorState = priorState,
         priorTimeoutTimestamp = priorTimeoutTimestamp,
-        expectedState = None)                                 // state should 
be removed
+        expectedState = Some(5),                                 // state 
should change
+        expectedTimeoutTimestamp = 5000)                         // timestamp 
should change
 
       testStateUpdateWithData(
-        testName + "timeout and state updated",
-        stateUpdates = state => { state.update(5); 
state.setTimeoutDuration(5000) },
-        timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
+        s"EventTimeTimeout - $testName - timeout timestamp updated to before 
watermark",
+        stateUpdates =
+          (state: KeyedState[Int]) => {
+            state.update(5)
+            intercept[IllegalArgumentException] {
+              state.setTimeoutTimestamp(currentBatchWatermark - 1)  // try to 
set to < watermark
+            }
+          },
+        timeoutConf = EventTimeTimeout,
         priorState = priorState,
         priorTimeoutTimestamp = priorTimeoutTimestamp,
-        expectedState = Some(5),                              // state should change
-        expectedTimeoutTimestamp = currentTimestamp + 5000)   // timestamp should change
+        expectedState = Some(5),                                 // state should change
+        expectedTimeoutTimestamp = NO_TIMESTAMP)                 // timestamp should not update
     }
   }
 
   // Tests for StateStoreUpdater.updateStateForTimedOutKeys()
   val preTimeoutState = Some(5)
+  for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) {
+    testStateUpdateWithTimeout(
+      s"$timeoutConf - should not timeout",
+      stateUpdates = state => { assert(false, "function called without timeout") },
+      timeoutConf = timeoutConf,
+      priorTimeoutTimestamp = afterTimeoutThreshold,
+      expectedState = preTimeoutState,                          // state should not change
+      expectedTimeoutTimestamp = afterTimeoutThreshold)         // timestamp should not change
+
+    testStateUpdateWithTimeout(
+      s"$timeoutConf - should timeout - no update/remove",
+      stateUpdates = state => { /* do nothing */ },
+      timeoutConf = timeoutConf,
+      priorTimeoutTimestamp = beforeTimeoutThreshold,
+      expectedState = preTimeoutState,                          // state should not change
+      expectedTimeoutTimestamp = NO_TIMESTAMP)     // timestamp should be reset
 
-  testStateUpdateWithTimeout(
-    "should not timeout",
-    stateUpdates = state => { assert(false, "function called without timeout") },
-    priorTimeoutTimestamp = afterCurrentTimestamp,
-    expectedState = preTimeoutState,                          // state should not change
-    expectedTimeoutTimestamp = afterCurrentTimestamp)         // timestamp should not change
+    testStateUpdateWithTimeout(
+      s"$timeoutConf - should timeout - update state",
+      stateUpdates = state => { state.update(5) },
+      timeoutConf = timeoutConf,
+      priorTimeoutTimestamp = beforeTimeoutThreshold,
+      expectedState = Some(5),                                  // state should change
+      expectedTimeoutTimestamp = NO_TIMESTAMP)     // timestamp should be reset
+
+    testStateUpdateWithTimeout(
+      s"$timeoutConf - should timeout - remove state",
+      stateUpdates = state => { state.remove() },
+      timeoutConf = timeoutConf,
+      priorTimeoutTimestamp = beforeTimeoutThreshold,
+      expectedState = None,                                     // state should be removed
+      expectedTimeoutTimestamp = NO_TIMESTAMP)
+  }
 
   testStateUpdateWithTimeout(
-    "should timeout - no update/remove",
-    stateUpdates = state => { /* do nothing */ },
-    priorTimeoutTimestamp = beforeCurrentTimestamp,
+    "ProcessingTimeTimeout - should timeout - timeout duration updated",
+    stateUpdates = state => { state.setTimeoutDuration(2000) },
+    timeoutConf = ProcessingTimeTimeout,
+    priorTimeoutTimestamp = beforeTimeoutThreshold,
     expectedState = preTimeoutState,                          // state should not change
-    expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET)     // timestamp should be reset
+    expectedTimeoutTimestamp = currentBatchTimestamp + 2000)       // timestamp should change
 
   testStateUpdateWithTimeout(
-    "should timeout - update state",
-    stateUpdates = state => { state.update(5) },
-    priorTimeoutTimestamp = beforeCurrentTimestamp,
+    "ProcessingTimeTimeout - should timeout - timeout duration and state updated",
+    stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) },
+    timeoutConf = ProcessingTimeTimeout,
+    priorTimeoutTimestamp = beforeTimeoutThreshold,
     expectedState = Some(5),                                  // state should change
-    expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET)     // timestamp should be reset
+    expectedTimeoutTimestamp = currentBatchTimestamp + 2000)  // timestamp should change
 
   testStateUpdateWithTimeout(
-    "should timeout - remove state",
-    stateUpdates = state => { state.remove() },
-    priorTimeoutTimestamp = beforeCurrentTimestamp,
-    expectedState = None,                                     // state should be removed
-    expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET)
-
-  testStateUpdateWithTimeout(
-    "should timeout - timeout updated",
-    stateUpdates = state => { state.setTimeoutDuration(2000) },
-    priorTimeoutTimestamp = beforeCurrentTimestamp,
+    "EventTimeTimeout - should timeout - timeout timestamp updated",
+    stateUpdates = state => { state.setTimeoutTimestamp(5000) },
+    timeoutConf = EventTimeTimeout,
+    priorTimeoutTimestamp = beforeTimeoutThreshold,
     expectedState = preTimeoutState,                          // state should not change
-    expectedTimeoutTimestamp = currentTimestamp + 2000)       // timestamp should change
+    expectedTimeoutTimestamp = 5000)                          // timestamp should change
 
   testStateUpdateWithTimeout(
-    "should timeout - timeout and state updated",
-    stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) },
-    priorTimeoutTimestamp = beforeCurrentTimestamp,
+    "EventTimeTimeout - should timeout - timeout and state updated",
+    stateUpdates = state => { state.update(5); state.setTimeoutTimestamp(5000) },
+    timeoutConf = EventTimeTimeout,
+    priorTimeoutTimestamp = beforeTimeoutThreshold,
     expectedState = Some(5),                                  // state should change
-    expectedTimeoutTimestamp = currentTimestamp + 2000)       // timestamp should change
+    expectedTimeoutTimestamp = 5000)                          // timestamp should change
 
   test("StateStoreUpdater - rows are cloned before writing to StateStore") {
     // function for running count
@@ -481,11 +551,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
 
     val clock = new StreamManualClock
     val inputData = MemoryStream[String]
-    val timeout = KeyedStateTimeout.ProcessingTimeTimeout
     val result =
       inputData.toDS()
         .groupByKey(x => x)
-        .flatMapGroupsWithState(Update, timeout)(stateFunc)
+        .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc)
 
     testStream(result, Update)(
       StartStream(ProcessingTime("1 second"), triggerClock = clock),
@@ -519,6 +588,52 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     )
   }
 
+  test("flatMapGroupsWithState - streaming with event time timeout") {
+    // Function to maintain the max event time
+    // Returns the max event time in the state, or -1 if the state was removed by timeout
+    val stateFunc = (
+        key: String,
+        values: Iterator[(String, Long)],
+        state: KeyedState[Long]) => {
+      val timeoutDelay = 5
+      if (key != "a") {
+        Iterator.empty
+      } else {
+        if (state.hasTimedOut) {
+          state.remove()
+          Iterator((key, -1))
+        } else {
+          val valuesSeq = values.toSeq
+          val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L))
+          val timeoutTimestampMs = maxEventTime + timeoutDelay
+          state.update(maxEventTime)
+          state.setTimeoutTimestamp(timeoutTimestampMs * 1000)
+          Iterator((key, maxEventTime.toInt))
+        }
+      }
+    }
+    val inputData = MemoryStream[(String, Int)]
+    val result =
+      inputData.toDS
+        .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime"))
+        .withWatermark("eventTime", "10 seconds")
+        .as[(String, Long)]
+        .groupByKey(_._1)
+        .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc)
+
+    testStream(result, Update)(
+      StartStream(ProcessingTime("1 second")),
+      AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ...
+      CheckLastBatch(("a", 15)),                           // "a" to 15 + 5 = 20s, watermark to 5s
+      AddData(inputData, ("a", 4)),       // Add data older than watermark for "a"
+      CheckLastBatch(),                   // No output as data should get filtered by watermark
+      AddData(inputData, ("dummy", 35)),  // Set watermark = 35 - 10 = 25s
+      CheckLastBatch(),                   // No output as no data for "a"
+      AddData(inputData, ("a", 24)),      // Add data older than watermark, should be ignored
+      CheckLastBatch(("a", -1))           // State for "a" should timeout and emit -1
+    )
+  }
+
   test("mapGroupsWithState - streaming") {
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
@@ -612,7 +727,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key
     val inputData = MemoryStream[String]
     val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc)
-    result
     testStream(result, Update)(
       AddData(inputData, "a"),
       CheckLastBatch("a"),
@@ -649,13 +763,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
   def testStateUpdateWithData(
       testName: String,
       stateUpdates: KeyedState[Int] => Unit,
-      timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout,
+      timeoutConf: KeyedStateTimeout,
       priorState: Option[Int],
-      priorTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET,
+      priorTimeoutTimestamp: Long = NO_TIMESTAMP,
       expectedState: Option[Int] = None,
-      expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = {
+      expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = {
 
-    if (priorState.isEmpty && priorTimeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET) {
+    if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) {
       return // there can be no prior timestamp, when there is no prior state
     }
     test(s"StateStoreUpdater - updates with data - $testName") {
@@ -666,7 +780,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
         Iterator.empty
       }
       testStateUpdate(
-        testTimeoutUpdates = false, mapGroupsFunc, timeoutType,
+        testTimeoutUpdates = false, mapGroupsFunc, timeoutConf,
         priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp)
     }
   }
@@ -674,9 +788,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
   def testStateUpdateWithTimeout(
       testName: String,
       stateUpdates: KeyedState[Int] => Unit,
+      timeoutConf: KeyedStateTimeout,
       priorTimeoutTimestamp: Long,
       expectedState: Option[Int],
-      expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = {
+      expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = {
 
     test(s"StateStoreUpdater - updates for timeout - $testName") {
       val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
@@ -686,16 +801,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
         Iterator.empty
       }
       testStateUpdate(
-        testTimeoutUpdates = true, mapGroupsFunc, KeyedStateTimeout.ProcessingTimeTimeout,
-        preTimeoutState, priorTimeoutTimestamp,
-        expectedState, expectedTimeoutTimestamp)
+        testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf,
+        preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp)
     }
   }
 
   def testStateUpdate(
       testTimeoutUpdates: Boolean,
       mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int],
-      timeoutType: KeyedStateTimeout,
+      timeoutConf: KeyedStateTimeout,
       priorState: Option[Int],
       priorTimeoutTimestamp: Long,
       expectedState: Option[Int],
@@ -703,7 +817,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
 
     val store = newStateStore()
     val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
-      mapGroupsFunc, timeoutType, currentTimestamp)
+      mapGroupsFunc, timeoutConf, currentBatchTimestamp)
     val updater = new mapGroupsSparkPlan.StateStoreUpdater(store)
     val key = intToRow(0)
     // Prepare store with prior state configs
@@ -736,7 +850,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
   def newFlatMapGroupsWithStateExec(
       func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int],
       timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout,
-      batchTimestampMs: Long = NO_BATCH_PROCESSING_TIMESTAMP): FlatMapGroupsWithStateExec = {
+      batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = {
     MemoryStream[Int]
       .toDS
       .groupByKey(x => x)
@@ -744,11 +858,31 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       .logicalPlan.collectFirst {
         case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
           FlatMapGroupsWithStateExec(
-            f, k, v, g, d, o, None, s, m, t, currentTimestamp,
-            RDDScanExec(g, null, "rdd"))
+            f, k, v, g, d, o, None, s, m, t,
+            Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd"))
       }.get
   }
 
+  def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = {
+    val prevTimestamp = state.getTimeoutTimestamp
+    intercept[T] { state.setTimeoutDuration(1000) }
+    assert(state.getTimeoutTimestamp === prevTimestamp)
+    intercept[T] { state.setTimeoutDuration("2 second") }
+    assert(state.getTimeoutTimestamp === prevTimestamp)
+  }
+
+  def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = {
+    val prevTimestamp = state.getTimeoutTimestamp
+    intercept[T] { state.setTimeoutTimestamp(2000) }
+    assert(state.getTimeoutTimestamp === prevTimestamp)
+    intercept[T] { state.setTimeoutTimestamp(2000, "1 second") }
+    assert(state.getTimeoutTimestamp === prevTimestamp)
+    intercept[T] { state.setTimeoutTimestamp(new Date(2000)) }
+    assert(state.getTimeoutTimestamp === prevTimestamp)
+    intercept[T] { state.setTimeoutTimestamp(new Date(2000), "1 second") }
+    assert(state.getTimeoutTimestamp === prevTimestamp)
+  }
+
   def newStateStore(): StateStore = new MemoryStateStore()
 
   val intProj = UnsafeProjection.create(Array[DataType](IntegerType))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to