Repository: spark
Updated Branches:
  refs/heads/master 1437e344e -> f3137feec


[SPARK-22278][SS] Expose current event time watermark and current processing 
time in GroupState

## What changes were proposed in this pull request?

Complex state-updating and/or timeout-handling logic in mapGroupsWithState
functions may require making decisions based on the current event-time
watermark and/or processing time. Currently, you can use the SQL function
`current_timestamp` to get the current processing time, but it has to be
inserted into every row with a select and then passed through the encoder,
which isn't efficient. Furthermore, there is no way to get the current
watermark.

This PR exposes both of them through the GroupState API.
It also cleans up some of the GroupState docs.

## How was this patch tested?

New unit tests

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #19495 from tdas/SPARK-22278.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f3137fee
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f3137fee
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f3137fee

Branch: refs/heads/master
Commit: f3137feecd30c74c47dbddb0e22b4ddf8cf2f912
Parents: 1437e34
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Tue Oct 17 20:09:12 2017 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Oct 17 20:09:12 2017 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/execution/objects.scala    |   8 +-
 .../streaming/FlatMapGroupsWithStateExec.scala  |   7 +-
 .../execution/streaming/GroupStateImpl.scala    |  50 +++---
 .../apache/spark/sql/streaming/GroupState.scala |  92 +++++++----
 .../streaming/FlatMapGroupsWithStateSuite.scala | 160 ++++++++++++++++---
 5 files changed, 238 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f3137fee/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index c68975b..d861109 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
-import org.apache.spark.sql.catalyst.plans.logical.{FunctionUtils, LogicalGroupState}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.streaming.GroupStateImpl
 import org.apache.spark.sql.streaming.GroupStateTimeout
@@ -361,8 +361,12 @@ object MapGroupsExec {
       outputObjAttr: Attribute,
       timeoutConf: GroupStateTimeout,
       child: SparkPlan): MapGroupsExec = {
+    val watermarkPresent = child.output.exists {
+      case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
+      case _ => false
+    }
     val f = (key: Any, values: Iterator[Any]) => {
-      func(key, values, GroupStateImpl.createForBatch(timeoutConf))
+      func(key, values, GroupStateImpl.createForBatch(timeoutConf, watermarkPresent))
     }
     new MapGroupsExec(f, keyDeserializer, valueDeserializer,
       groupingAttributes, dataAttributes, outputObjAttr, child)

http://git-wip-us.apache.org/repos/asf/spark/blob/f3137fee/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index c81f1a8..29f38fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -61,6 +61,10 @@ case class FlatMapGroupsWithStateExec(
 
   private val isTimeoutEnabled = timeoutConf != NoTimeout
   val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled)
+  val watermarkPresent = child.output.exists {
+    case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
+    case _ => false
+  }
 
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] =
@@ -190,7 +194,8 @@ case class FlatMapGroupsWithStateExec(
         batchTimestampMs.getOrElse(NO_TIMESTAMP),
         eventTimeWatermark.getOrElse(NO_TIMESTAMP),
         timeoutConf,
-        hasTimedOut)
+        hasTimedOut,
+        watermarkPresent)
 
       // Call function, get the returned objects and convert them to rows
       val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>

http://git-wip-us.apache.org/repos/asf/spark/blob/f3137fee/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
index 4401e86..7f65e3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
@@ -43,7 +43,8 @@ private[sql] class GroupStateImpl[S] private(
     batchProcessingTimeMs: Long,
     eventTimeWatermarkMs: Long,
     timeoutConf: GroupStateTimeout,
-    override val hasTimedOut: Boolean) extends GroupState[S] {
+    override val hasTimedOut: Boolean,
+    watermarkPresent: Boolean) extends GroupState[S] {
 
   private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
   private var defined: Boolean = optionalValue.isDefined
@@ -90,7 +91,7 @@ private[sql] class GroupStateImpl[S] private(
     if (timeoutConf != ProcessingTimeTimeout) {
       throw new UnsupportedOperationException(
         "Cannot set timeout duration without enabling processing time timeout in " +
-          "map/flatMapGroupsWithState")
+          "[map|flatMap]GroupsWithState")
     }
     if (durationMs <= 0) {
       throw new IllegalArgumentException("Timeout duration must be positive")
@@ -102,10 +103,6 @@ private[sql] class GroupStateImpl[S] private(
     setTimeoutDuration(parseDuration(duration))
   }
 
-  @throws[IllegalArgumentException]("if 'timestampMs' is not positive")
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
-  @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
   override def setTimeoutTimestamp(timestampMs: Long): Unit = {
     checkTimeoutTimestampAllowed()
     if (timestampMs <= 0) {
@@ -119,32 +116,34 @@ private[sql] class GroupStateImpl[S] private(
     timeoutTimestamp = timestampMs
   }
 
-  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
-  @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
   override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = {
     checkTimeoutTimestampAllowed()
     setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs)
   }
 
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
-  @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
   override def setTimeoutTimestamp(timestamp: Date): Unit = {
     checkTimeoutTimestampAllowed()
     setTimeoutTimestamp(timestamp.getTime)
   }
 
-  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
-  @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
   override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = {
     checkTimeoutTimestampAllowed()
     setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration))
   }
 
+  override def getCurrentWatermarkMs(): Long = {
+    if (!watermarkPresent) {
+      throw new UnsupportedOperationException(
+        "Cannot get event time watermark timestamp without setting watermark before " +
+          "[map|flatMap]GroupsWithState")
+    }
+    eventTimeWatermarkMs
+  }
+
+  override def getCurrentProcessingTimeMs(): Long = {
+    batchProcessingTimeMs
+  }
+
   override def toString: String = {
     s"GroupState(${getOption.map(_.toString).getOrElse("<undefined>")})"
   }
@@ -187,7 +186,7 @@ private[sql] class GroupStateImpl[S] private(
     if (timeoutConf != EventTimeTimeout) {
       throw new UnsupportedOperationException(
         "Cannot set timeout timestamp without enabling event time timeout in " +
-          "map/flatMapGroupsWithState")
+          "[map|flatMap]GroupsWithState")
     }
   }
 }
@@ -202,17 +201,22 @@ private[sql] object GroupStateImpl {
       batchProcessingTimeMs: Long,
       eventTimeWatermarkMs: Long,
       timeoutConf: GroupStateTimeout,
-      hasTimedOut: Boolean): GroupStateImpl[S] = {
+      hasTimedOut: Boolean,
+      watermarkPresent: Boolean): GroupStateImpl[S] = {
     new GroupStateImpl[S](
-      optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut)
+      optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs,
+      timeoutConf, hasTimedOut, watermarkPresent)
   }
 
-  def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = {
+  def createForBatch(
+      timeoutConf: GroupStateTimeout,
+      watermarkPresent: Boolean): GroupStateImpl[Any] = {
     new GroupStateImpl[Any](
       optionalValue = None,
-      batchProcessingTimeMs = NO_TIMESTAMP,
+      batchProcessingTimeMs = System.currentTimeMillis(),
       eventTimeWatermarkMs = NO_TIMESTAMP,
       timeoutConf,
-      hasTimedOut = false)
+      hasTimedOut = false,
+      watermarkPresent)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f3137fee/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
index 04a956b..e9510c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
@@ -205,11 +205,7 @@ trait GroupState[S] extends LogicalGroupState[S] {
   /** Get the state value as a scala Option. */
   def getOption: Option[S]
 
-  /**
-   * Update the value of the state. Note that `null` is not a valid value, and it throws
-   * IllegalArgumentException.
-   */
-  @throws[IllegalArgumentException]("when updating with null")
+  /** Update the value of the state. */
   def update(newState: S): Unit
 
   /** Remove this state. */
@@ -217,80 +213,114 @@ trait GroupState[S] extends LogicalGroupState[S] {
 
   /**
    * Whether the function has been called because the key has timed out.
-   * @note This can return true only when timeouts are enabled in `[map/flatmap]GroupsWithStates`.
+   * @note This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`.
    */
   def hasTimedOut: Boolean
 
+
   /**
    * Set the timeout duration in ms for this key.
    *
-   * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`.
+   * @note [[GroupStateTimeout Processing time timeout]] must be enabled in
+   *       `[map/flatMap]GroupsWithState` for calling this method.
+   * @note This method has no effect when used in a batch query.
    */
   @throws[IllegalArgumentException]("if 'durationMs' is not positive")
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
   @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
+    "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState")
   def setTimeoutDuration(durationMs: Long): Unit
 
+
   /**
    * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc.
    *
-   * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`.
+   * @note [[GroupStateTimeout Processing time timeout]] must be enabled in
+   *       `[map/flatMap]GroupsWithState` for calling this method.
+   * @note This method has no effect when used in a batch query.
    */
   @throws[IllegalArgumentException]("if 'duration' is not a valid duration")
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
   @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
+    "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState")
   def setTimeoutDuration(duration: String): Unit
 
-  @throws[IllegalArgumentException]("if 'timestampMs' is not positive")
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
-  @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
+
   /**
    * Set the timeout timestamp for this key as milliseconds in epoch time.
    * This timestamp cannot be older than the current watermark.
    *
-   * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`.
+   * @note [[GroupStateTimeout Event time timeout]] must be enabled in
+   *       `[map/flatMap]GroupsWithState` for calling this method.
+   * @note This method has no effect when used in a batch query.
    */
+  @throws[IllegalArgumentException](
+    "if 'timestampMs' is not positive or less than the current watermark in a streaming query")
+  @throws[UnsupportedOperationException](
+    "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
   def setTimeoutTimestamp(timestampMs: Long): Unit
 
-  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
-  @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
+
   /**
    * Set the timeout timestamp for this key as milliseconds in epoch time and an additional
    * duration as a string (e.g. "1 hour", "2 days", etc.).
    * The final timestamp (including the additional duration) cannot be older than the
    * current watermark.
    *
-   * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`.
+   * @note [[GroupStateTimeout Event time timeout]] must be enabled in
+   *       `[map/flatMap]GroupsWithState` for calling this method.
+   * @note This method has no effect when used in a batch query.
    */
+  @throws[IllegalArgumentException](
+    "if 'additionalDuration' is invalid or the final timeout timestamp is less than " +
+      "the current watermark in a streaming query")
+  @throws[UnsupportedOperationException](
+    "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
   def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit
 
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
-  @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
+
   /**
    * Set the timeout timestamp for this key as a java.sql.Date.
    * This timestamp cannot be older than the current watermark.
    *
-   * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`.
+   * @note [[GroupStateTimeout Event time timeout]] must be enabled in
+   *       `[map/flatMap]GroupsWithState` for calling this method.
+   * @note This method has no effect when used in a batch query.
    */
+  @throws[UnsupportedOperationException](
+    "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
   def setTimeoutTimestamp(timestamp: java.sql.Date): Unit
 
-  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
-  @throws[IllegalStateException]("when state is either not initialized, or already removed")
-  @throws[UnsupportedOperationException](
-    "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
+
   /**
    * Set the timeout timestamp for this key as a java.sql.Date and an additional
    * duration as a string (e.g. "1 hour", "2 days", etc.).
    * The final timestamp (including the additional duration) cannot be older than the
    * current watermark.
    *
-   * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`.
+   * @note [[GroupStateTimeout Event time timeout]] must be enabled in
+   *       `[map/flatMap]GroupsWithState` for calling this method.
+   * @note This method has no effect when used in a batch query.
    */
+  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
+  @throws[UnsupportedOperationException](
+    "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
   def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit
+
+
+  /**
+   * Get the current event time watermark as milliseconds in epoch time.
+   *
+   * @note In a streaming query, this can be called only when watermark is set before calling
+   *       `[map/flatMap]GroupsWithState`. In a batch query, this returns -1 if a watermark is set.
+   */
+  @throws[UnsupportedOperationException](
+    "if watermark has not been set before in [map|flatMap]GroupsWithState")
+  def getCurrentWatermarkMs(): Long
+
+
+  /**
+   * Get the current processing time as milliseconds in epoch time.
+   * @note In a streaming query, this will return a constant value throughout the duration of a
+   *       trigger, even if the trigger is re-executed.
+   */
+  def getCurrentProcessingTimeMs(): Long
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f3137fee/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index aeb8383..af08186 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -21,6 +21,7 @@ import java.sql.Date
 import java.util.concurrent.ConcurrentHashMap
 
 import org.scalatest.BeforeAndAfterAll
+import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.SparkException
 import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
@@ -48,6 +49,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
   import testImplicits._
   import GroupStateImpl._
   import GroupStateTimeout._
+  import FlatMapGroupsWithStateSuite._
 
   override def afterAll(): Unit = {
     super.afterAll()
@@ -77,13 +79,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
 
     // === Tests for state in streaming queries ===
     // Updating empty state
-    state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false)
+    state = GroupStateImpl.createForStreaming(
+      None, 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false)
     testState(None)
     state.update("")
     testState(Some(""), shouldBeUpdated = true)
 
     // Updating exiting state
-    state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false)
+    state = GroupStateImpl.createForStreaming(
+      Some("2"), 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false)
     testState(Some("2"))
     state.update("3")
     testState(Some("3"), shouldBeUpdated = true)
@@ -104,8 +108,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
   test("GroupState - setTimeout - with NoTimeout") {
     for (initValue <- Seq(None, Some(5))) {
       val states = Seq(
-        GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false),
-        GroupStateImpl.createForBatch(NoTimeout)
+        GroupStateImpl.createForStreaming(
+          initValue, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false),
+        GroupStateImpl.createForBatch(NoTimeout, watermarkPresent = false)
       )
       for (state <- states) {
         // for streaming queries
@@ -122,7 +127,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
   test("GroupState - setTimeout - with ProcessingTimeTimeout") {
     // for streaming queries
     var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
-      None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
+      None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false)
     assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
     state.setTimeoutDuration(500)
     assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state
@@ -143,7 +148,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
 
     // for batch queries
-    state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]]
+    state = GroupStateImpl.createForBatch(
+      ProcessingTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]]
     assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
     state.setTimeoutDuration(500)
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
@@ -160,7 +166,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
 
   test("GroupState - setTimeout - with EventTimeTimeout") {
     var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
-      None, 1000, 1000, EventTimeTimeout, false)
+      None, 1000, 1000, EventTimeTimeout, hasTimedOut = false, watermarkPresent = true)
 
     assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
     testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
@@ -182,7 +188,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
 
     // for batch queries
-    state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]]
+    state = GroupStateImpl.createForBatch(EventTimeTimeout, watermarkPresent = false)
+      .asInstanceOf[GroupStateImpl[Int]]
     assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
     testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
     state.setTimeoutTimestamp(5000)
@@ -209,7 +216,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     }
 
     state = GroupStateImpl.createForStreaming(
-      Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
+      Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false)
     testIllegalTimeout {
       state.setTimeoutDuration(-1000)
     }
@@ -227,7 +234,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     }
 
     state = GroupStateImpl.createForStreaming(
-      Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
+      Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false, watermarkPresent = false)
     testIllegalTimeout {
       state.setTimeoutTimestamp(-10000)
     }
@@ -259,29 +266,92 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       // for streaming queries
       for (initState <- Seq(None, Some(5))) {
         val state1 = GroupStateImpl.createForStreaming(
-          initState, 1000, 1000, timeoutConf, hasTimedOut = false)
+          initState, 1000, 1000, timeoutConf, hasTimedOut = false, watermarkPresent = false)
         assert(state1.hasTimedOut === false)
 
         val state2 = GroupStateImpl.createForStreaming(
-          initState, 1000, 1000, timeoutConf, hasTimedOut = true)
+          initState, 1000, 1000, timeoutConf, hasTimedOut = true, watermarkPresent = false)
         assert(state2.hasTimedOut === true)
       }
 
       // for batch queries
-      assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false)
+      assert(
+        GroupStateImpl.createForBatch(timeoutConf, watermarkPresent = false).hasTimedOut === false)
+    }
+  }
+
+  test("GroupState - getCurrentWatermarkMs") {
+    def streamingState(timeoutConf: GroupStateTimeout, watermark: Option[Long]): GroupState[Int] = {
+      GroupStateImpl.createForStreaming(
+        None, 1000, watermark.getOrElse(-1), timeoutConf,
+        hasTimedOut = false, watermarkPresent = watermark.nonEmpty)
+    }
+
+    def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = {
+      GroupStateImpl.createForBatch(timeoutConf, watermarkPresent)
+    }
+
+    def assertWrongTimeoutError(test: => Unit): Unit = {
+      val e = intercept[UnsupportedOperationException] { test }
+      assert(e.getMessage.contains(
+        "Cannot get event time watermark timestamp without setting watermark"))
+    }
+
+    for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) {
+      // Tests for getCurrentWatermarkMs in streaming queries
+      assertWrongTimeoutError { streamingState(timeoutConf, None).getCurrentWatermarkMs() }
+      assert(streamingState(timeoutConf, Some(1000)).getCurrentWatermarkMs() === 1000)
+      assert(streamingState(timeoutConf, Some(2000)).getCurrentWatermarkMs() === 2000)
+
+      // Tests for getCurrentWatermarkMs in batch queries
+      assertWrongTimeoutError {
+        batchState(timeoutConf, watermarkPresent = false).getCurrentWatermarkMs()
+      }
+      assert(batchState(timeoutConf, watermarkPresent = true).getCurrentWatermarkMs() === -1)
+    }
+  }
+
+  test("GroupState - getCurrentProcessingTimeMs") {
+    def streamingState(
+        timeoutConf: GroupStateTimeout,
+        procTime: Long,
+        watermarkPresent: Boolean): GroupState[Int] = {
+      GroupStateImpl.createForStreaming(
+        None, procTime, -1, timeoutConf, hasTimedOut = false, watermarkPresent = watermarkPresent)
+    }
+
+    def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = {
+      GroupStateImpl.createForBatch(timeoutConf, watermarkPresent)
+    }
+
+    for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) {
+      for (watermarkPresent <- Seq(false, true)) {
+        // Tests for getCurrentProcessingTimeMs in streaming queries
+        assert(streamingState(timeoutConf, NO_TIMESTAMP, watermarkPresent)
+          .getCurrentProcessingTimeMs() === -1)
+        assert(streamingState(timeoutConf, 1000, watermarkPresent)
+          .getCurrentProcessingTimeMs() === 1000)
+        assert(streamingState(timeoutConf, 2000, watermarkPresent)
+          .getCurrentProcessingTimeMs() === 2000)
+
+        // Tests for getCurrentProcessingTimeMs in batch queries
+        val startTime = System.currentTimeMillis()
+        assert(batchState(timeoutConf, watermarkPresent).getCurrentProcessingTimeMs() >= startTime)
+      }
     }
   }
 
+
   test("GroupState - primitive type") {
     var intState = GroupStateImpl.createForStreaming[Int](
-      None, 1000, 1000, NoTimeout, hasTimedOut = false)
+      None, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false)
     intercept[NoSuchElementException] {
       intState.get
     }
     assert(intState.getOption === None)
 
     intState = GroupStateImpl.createForStreaming[Int](
-      Some(10), 1000, 1000, NoTimeout, hasTimedOut = false)
+      Some(10), 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false)
     assert(intState.get == 10)
     intState.update(0)
     assert(intState.get == 0)
@@ -304,7 +374,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
 
     testStateUpdateWithData(
       testName + "no update",
-      stateUpdates = state => { /* do nothing */ },
+      stateUpdates = state => {
+        assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp)
+        assertCannotGetWatermark { state.getCurrentWatermarkMs() } // watermark not specified
+        /* no updates */
+      },
       timeoutConf = GroupStateTimeout.NoTimeout,
       priorState = priorState,
       expectedState = priorState)    // should not change
@@ -342,7 +416,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
 
         testStateUpdateWithData(
           s"$timeoutConf - $testName - no update",
-          stateUpdates = state => { /* do nothing */ },
+          stateUpdates = state => {
+            assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp)
+            assertCannotGetWatermark { state.getCurrentWatermarkMs() } // watermark not specified
+            /* no updates */
+          },
           timeoutConf = timeoutConf,
           priorState = priorState,
           priorTimeoutTimestamp = priorTimeoutTimestamp,
@@ -466,7 +544,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
 
     testStateUpdateWithTimeout(
       s"$timeoutConf - should timeout - no update/remove",
-      stateUpdates = state => { /* do nothing */ },
+      stateUpdates = state => {
+        assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp)
+        assertCannotGetWatermark { state.getCurrentWatermarkMs() } // watermark not specified
+        /* no updates */
+      },
       timeoutConf = timeoutConf,
       priorTimeoutTimestamp = beforeTimeoutThreshold,
       expectedState = preTimeoutState,                          // state should not change
@@ -525,6 +607,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count if state is defined, otherwise does not return anything
     val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+      assertCannotGetWatermark { state.getCurrentWatermarkMs() }
 
       val count = state.getOption.map(_.count).getOrElse(0L) + values.size
       if (count == 3) {
@@ -647,6 +731,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
   test("flatMapGroupsWithState - batch") {
     // Function that returns running count only if its even, otherwise does not return
     val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 }
+      assertCannotGetWatermark { state.getCurrentWatermarkMs() }
+
       if (state.exists) throw new IllegalArgumentException("state.exists should be false")
       Iterator((key, values.size))
     }
@@ -660,6 +747,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
     val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+      assertCannotGetWatermark { state.getCurrentWatermarkMs() }
+
       if (state.hasTimedOut) {
         state.remove()
         Iterator((key, "-1"))
@@ -713,10 +803,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
   test("flatMapGroupsWithState - streaming with event time timeout + watermark") {
     // Function to maintain the max event time
     // Returns the max event time in the state, or -1 if the state was removed by timeout
-    val stateFunc = (
-        key: String,
-        values: Iterator[(String, Long)],
-        state: GroupState[Long]) => {
+    val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => {
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+      assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 }
+
       val timeoutDelay = 5
       if (key != "a") {
         Iterator.empty
@@ -760,6 +850,8 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count (-1 if count reached beyond 2 and state 
was just removed)
     val stateFunc = (key: String, values: Iterator[String], state: 
GroupState[RunningCount]) => {
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+      assertCannotGetWatermark { state.getCurrentWatermarkMs() }
 
       val count = state.getOption.map(_.count).getOrElse(0L) + values.size
       if (count == 3) {
@@ -802,7 +894,11 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest
     // - no initial state
     // - timeouts operations work, does not throw any error [SPARK-20792]
     // - works with primitive state type
+    // - can get processing time
     val stateFunc = (key: String, values: Iterator[String], state: 
GroupState[Int]) => {
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 }
+      assertCannotGetWatermark { state.getCurrentWatermarkMs() }
+
       if (state.exists) throw new IllegalArgumentException("state.exists 
should be false")
       state.setTimeoutTimestamp(0, "1 hour")
       state.update(10)
@@ -1090,4 +1186,24 @@ object FlatMapGroupsWithStateSuite {
     override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 
0, Map.empty)
     override def hasCommitted: Boolean = true
   }
+
+  /**
+   * Fails the current test with a ScalaTest `TestFailedException` when `predicate` is false.
+   * State functions call this with a check on `state.getCurrentProcessingTimeMs()`.
+   * The by-name `predicate` is evaluated exactly once; an exception it throws propagates as-is.
+   */
+  def assertCanGetProcessingTime(predicate: => Boolean): Unit = {
+    val processingTimeReadable = predicate
+    if (!processingTimeReadable) {
+      // 20 is handed to ScalaTest as the failed-code stack depth, same as the sibling helpers.
+      throw new TestFailedException("Could not get processing time", 20)
+    }
+  }
+
+  /**
+   * Fails the current test with a ScalaTest `TestFailedException` when `predicate` is false.
+   * State functions call this with a check on `state.getCurrentWatermarkMs()`, so the failure
+   * message names the watermark (not the processing time) to identify which check failed.
+   * The by-name `predicate` is evaluated exactly once; an exception it throws propagates as-is.
+   */
+  def assertCanGetWatermark(predicate: => Boolean): Unit = {
+    if (!predicate) throw new TestFailedException("Could not get watermark", 20)
+  }
+
+  /**
+   * Passes only when evaluating `func` throws an `UnsupportedOperationException`. State
+   * functions use it to check that `state.getCurrentWatermarkMs()` cannot be read, e.g. when
+   * the query defines no watermark.
+   *
+   * Fails the current test with a `TestFailedException` when:
+   *  - `func` completes normally ("Could get watermark when not expected"), or
+   *  - `func` throws any other non-fatal exception. That exception is attached as the cause,
+   *    so the real error stays visible in the test report.
+   * Fatal throwables (as classified by `NonFatal`) are not intercepted and propagate unchanged.
+   */
+  def assertCannotGetWatermark(func: => Unit): Unit = {
+    val threwExpectedException =
+      try {
+        func
+        false
+      } catch {
+        case _: UnsupportedOperationException => true
+        case scala.util.control.NonFatal(e) =>
+          throw new TestFailedException(
+            "Unexpected exception when trying to get watermark", e, 20)
+      }
+    if (!threwExpectedException) {
+      throw new TestFailedException("Could get watermark when not expected", 20)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to