Repository: spark
Updated Branches:
  refs/heads/branch-2.2 3aad5982a -> cfd1bf0be


[SPARK-20792][SS] Support same timeout operations in mapGroupsWithState 
function in batch queries as in streaming queries

## What changes were proposed in this pull request?

Currently, in batch queries, timeouts are disabled (i.e.
GroupStateTimeout.NoTimeout), which means any GroupState.setTimeout*** operation
throws UnsupportedOperationException. This makes it awkward to convert a
streaming query into a batch query simply by changing the input DF from a
streaming DF to a batch DF: if the streaming query enabled and used timeouts,
the batch query starts throwing UnsupportedOperationException.

This PR creates the dummy state in batch queries with the provided timeoutConf
so that timeout operations behave the same way as in streaming queries. The code
has also been refactored to make it obvious whether the state is being created
for a batch query or for a streaming query.
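
As a rough illustration (not part of this patch), the sketch below shows the kind
of user code this enables: a stateful function that sets a timeout, applied
unchanged to a batch Dataset. The object and value names are hypothetical.

```scala
// Hedged sketch: a stateful function written with streaming in mind, reused
// as-is on a batch Dataset. With this change, setTimeoutDuration no longer
// throws in the batch case; batch queries simply never fire timeouts.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

object BatchTimeoutExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("batch-timeout").getOrCreate()
    import spark.implicits._

    // Counts values per key and sets a processing-time timeout, as a streaming query would.
    val countFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
      state.setTimeoutDuration("1 hour") // previously threw UnsupportedOperationException in batch
      val count = state.getOption.getOrElse(0L) + values.size
      state.update(count)
      (key, count)
    }

    // Batch input; the same function now runs without error.
    spark.createDataset(Seq("a", "a", "b"))
      .groupByKey(identity)
      .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(countFunc)
      .show()

    spark.stop()
  }
}
```

In a streaming query the timeout could eventually invoke the function again with
hasTimedOut = true; in a batch query the call is accepted but no timeout ever fires.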

## How was this patch tested?
Additional unit tests in FlatMapGroupsWithStateSuite.

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #18024 from tdas/SPARK-20792.

(cherry picked from commit 9d6661c829a4a82aae64ed0522c44e4c3d8f4f0b)
Signed-off-by: Shixiong Zhu <shixi...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cfd1bf0b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cfd1bf0b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cfd1bf0b

Branch: refs/heads/branch-2.2
Commit: cfd1bf0bef766a9b13fe16bcca172d4108eb4e56
Parents: 3aad598
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Sun May 21 13:07:25 2017 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Sun May 21 13:07:32 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |   5 +-
 .../apache/spark/sql/execution/objects.scala    |   6 +-
 .../streaming/FlatMapGroupsWithStateExec.scala  |   2 +-
 .../execution/streaming/GroupStateImpl.scala    |  42 +++----
 .../streaming/FlatMapGroupsWithStateSuite.scala | 113 ++++++++++++++-----
 5 files changed, 116 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cfd1bf0b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ca2f6dd..73541c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -383,8 +383,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
         execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
       case logical.FlatMapGroupsWithState(
-          f, key, value, grouping, data, output, _, _, _, _, child) =>
-        execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
+          f, key, value, grouping, data, output, _, _, _, timeout, child) =>
+        execution.MapGroupsExec(
+          f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
       case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
         execution.CoGroupExec(
           f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,

http://git-wip-us.apache.org/repos/asf/spark/blob/cfd1bf0b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 48c7b80..3439181 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
 import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.streaming.GroupStateTimeout
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -361,8 +362,11 @@ object MapGroupsExec {
       groupingAttributes: Seq[Attribute],
       dataAttributes: Seq[Attribute],
       outputObjAttr: Attribute,
+      timeoutConf: GroupStateTimeout,
       child: SparkPlan): MapGroupsExec = {
-    val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None))
+    val f = (key: Any, values: Iterator[Any]) => {
+      func(key, values, GroupStateImpl.createForBatch(timeoutConf))
+    }
     new MapGroupsExec(f, keyDeserializer, valueDeserializer,
       groupingAttributes, dataAttributes, outputObjAttr, child)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/cfd1bf0b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index bd8d5d7..3ceb4cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -215,7 +215,7 @@ case class FlatMapGroupsWithStateExec(
       val keyObj = getKeyObj(keyRow)  // convert key to objects
       val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
       val stateObjOption = getStateObj(prevStateRowOption)
-      val keyedState = new GroupStateImpl(
+      val keyedState = GroupStateImpl.createForStreaming(
         stateObjOption,
         batchTimestampMs.getOrElse(NO_TIMESTAMP),
         eventTimeWatermark.getOrElse(NO_TIMESTAMP),

http://git-wip-us.apache.org/repos/asf/spark/blob/cfd1bf0b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
index d4606fd5..4401e86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
@@ -38,20 +38,13 @@ import org.apache.spark.unsafe.types.CalendarInterval
  * @param hasTimedOut     Whether the key for which this state wrapped is being created is
  *                        getting timed out or not.
  */
-private[sql] class GroupStateImpl[S](
+private[sql] class GroupStateImpl[S] private(
     optionalValue: Option[S],
     batchProcessingTimeMs: Long,
     eventTimeWatermarkMs: Long,
     timeoutConf: GroupStateTimeout,
     override val hasTimedOut: Boolean) extends GroupState[S] {
 
-  // Constructor to create dummy state when using mapGroupsWithState in a batch query
-  def this(optionalValue: Option[S]) = this(
-    optionalValue,
-    batchProcessingTimeMs = NO_TIMESTAMP,
-    eventTimeWatermarkMs = NO_TIMESTAMP,
-    timeoutConf = GroupStateTimeout.NoTimeout,
-    hasTimedOut = false)
   private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
   private var defined: Boolean = optionalValue.isDefined
   private var updated: Boolean = false // whether value has been updated (but not removed)
@@ -102,12 +95,7 @@ private[sql] class GroupStateImpl[S](
     if (durationMs <= 0) {
       throw new IllegalArgumentException("Timeout duration must be positive")
     }
-    if (batchProcessingTimeMs != NO_TIMESTAMP) {
-      timeoutTimestamp = durationMs + batchProcessingTimeMs
-    } else {
-      // This is being called in a batch query, hence no processing timestamp.
-      // Just ignore any attempts to set timeout.
-    }
+    timeoutTimestamp = durationMs + batchProcessingTimeMs
   }
 
   override def setTimeoutDuration(duration: String): Unit = {
@@ -128,12 +116,7 @@ private[sql] class GroupStateImpl[S](
         s"Timeout timestamp ($timestampMs) cannot be earlier than the " +
           s"current watermark ($eventTimeWatermarkMs)")
     }
-    if (batchProcessingTimeMs != NO_TIMESTAMP) {
-      timeoutTimestamp = timestampMs
-    } else {
-      // This is being called in a batch query, hence no processing timestamp.
-      // Just ignore any attempts to set timeout.
-    }
+    timeoutTimestamp = timestampMs
   }
 
   @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
@@ -213,4 +196,23 @@ private[sql] class GroupStateImpl[S](
 private[sql] object GroupStateImpl {
   // Value used represent the lack of valid timestamp as a long
   val NO_TIMESTAMP = -1L
+
+  def createForStreaming[S](
+      optionalValue: Option[S],
+      batchProcessingTimeMs: Long,
+      eventTimeWatermarkMs: Long,
+      timeoutConf: GroupStateTimeout,
+      hasTimedOut: Boolean): GroupStateImpl[S] = {
+    new GroupStateImpl[S](
+      optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut)
+  }
+
+  def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = {
+    new GroupStateImpl[Any](
+      optionalValue = None,
+      batchProcessingTimeMs = NO_TIMESTAMP,
+      eventTimeWatermarkMs = NO_TIMESTAMP,
+      timeoutConf,
+      hasTimedOut = false)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/cfd1bf0b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 10e9174..6bb9408 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -73,14 +73,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       assert(state.hasRemoved === shouldBeRemoved)
     }
 
+    // === Tests for state in streaming queries ===
     // Updating empty state
-    state = new GroupStateImpl[String](None)
+    state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false)
     testState(None)
     state.update("")
     testState(Some(""), shouldBeUpdated = true)
 
     // Updating exiting state
-    state = new GroupStateImpl[String](Some("2"))
+    state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false)
     testState(Some("2"))
     state.update("3")
     testState(Some("3"), shouldBeUpdated = true)
@@ -99,25 +100,34 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
   }
 
   test("GroupState - setTimeout**** with NoTimeout") {
-    for (initState <- Seq(None, Some(5))) {
-      // for different initial state
-      implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false)
-      testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
-      testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+    for (initValue <- Seq(None, Some(5))) {
+      val states = Seq(
+        GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false),
+        GroupStateImpl.createForBatch(NoTimeout)
+      )
+      for (state <- states) {
+        // for streaming queries
+        testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+        testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+        // for batch queries
+        testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+        testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+      }
     }
   }
 
   test("GroupState - setTimeout**** with ProcessingTimeTimeout") {
-    implicit var state: GroupStateImpl[Int] = null
-
-    state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
+    // for streaming queries
+    var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
+      None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
     assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
     state.setTimeoutDuration(500)
-    assert(state.getTimeoutTimestamp === 1500)    // can be set without initializing state
+    assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
 
     state.update(5)
-    assert(state.getTimeoutTimestamp === 1500)    // does not change
+    assert(state.getTimeoutTimestamp === 1500) // does not change
     state.setTimeoutDuration(1000)
     assert(state.getTimeoutTimestamp === 2000)
     state.setTimeoutDuration("2 second")
@@ -125,22 +135,38 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
 
     state.remove()
-    assert(state.getTimeoutTimestamp === 3000)    // does not change
-    state.setTimeoutDuration(500)                 // can still be set
+    assert(state.getTimeoutTimestamp === 3000) // does not change
+    state.setTimeoutDuration(500) // can still be set
     assert(state.getTimeoutTimestamp === 1500)
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+    // for batch queries
+    state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]]
+    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    state.setTimeoutDuration(500)
+    testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+    state.update(5)
+    state.setTimeoutDuration(1000)
+    state.setTimeoutDuration("2 second")
+    testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+    state.remove()
+    state.setTimeoutDuration(500)
+    testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
   }
 
   test("GroupState - setTimeout**** with EventTimeTimeout") {
-    implicit val state = new GroupStateImpl[Int](
-      None, 1000, 1000, EventTimeTimeout, hasTimedOut = false)
+    var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
+      None, 1000, 1000, EventTimeTimeout, false)
+
     assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
     testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
     state.setTimeoutTimestamp(5000)
-    assert(state.getTimeoutTimestamp === 5000)    // can be set without initializing state
+    assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state
 
     state.update(5)
-    assert(state.getTimeoutTimestamp === 5000)    // does not change
+    assert(state.getTimeoutTimestamp === 5000) // does not change
     state.setTimeoutTimestamp(10000)
     assert(state.getTimeoutTimestamp === 10000)
     state.setTimeoutTimestamp(new Date(20000))
@@ -150,7 +176,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     state.remove()
     assert(state.getTimeoutTimestamp === 20000)
     state.setTimeoutTimestamp(5000)
-    assert(state.getTimeoutTimestamp === 5000)    // can be set after removing state
+    assert(state.getTimeoutTimestamp === 5000) // can be set after removing state
+    testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+
+    // for batch queries
+    state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]]
+    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+    state.setTimeoutTimestamp(5000)
+
+    state.update(5)
+    state.setTimeoutTimestamp(10000)
+    state.setTimeoutTimestamp(new Date(20000))
+    testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+
+    state.remove()
+    state.setTimeoutTimestamp(5000)
     testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
   }
 
@@ -165,7 +206,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
     }
 
-    state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
+    state = GroupStateImpl.createForStreaming(
+      Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
     testIllegalTimeout {
       state.setTimeoutDuration(-1000)
     }
@@ -182,7 +224,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       state.setTimeoutDuration("1 month -1 day")
     }
 
-    state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
+    state = GroupStateImpl.createForStreaming(
+      Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
     testIllegalTimeout {
       state.setTimeoutTimestamp(-10000)
     }
@@ -211,23 +254,32 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
 
   test("GroupState - hasTimedOut") {
     for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) {
+      // for streaming queries
       for (initState <- Seq(None, Some(5))) {
-        val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false)
+        val state1 = GroupStateImpl.createForStreaming(
+          initState, 1000, 1000, timeoutConf, hasTimedOut = false)
         assert(state1.hasTimedOut === false)
-        val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true)
+
+        val state2 = GroupStateImpl.createForStreaming(
+          initState, 1000, 1000, timeoutConf, hasTimedOut = true)
         assert(state2.hasTimedOut === true)
       }
+
+      // for batch queries
+      assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false)
     }
   }
 
   test("GroupState - primitive type") {
-    var intState = new GroupStateImpl[Int](None)
+    var intState = GroupStateImpl.createForStreaming[Int](
+      None, 1000, 1000, NoTimeout, hasTimedOut = false)
     intercept[NoSuchElementException] {
       intState.get
     }
     assert(intState.getOption === None)
 
-    intState = new GroupStateImpl[Int](Some(10))
+    intState = GroupStateImpl.createForStreaming[Int](
+      Some(10), 1000, 1000, NoTimeout, hasTimedOut = false)
     assert(intState.get == 10)
     intState.update(0)
     assert(intState.get == 0)
@@ -243,7 +295,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
   val beforeTimeoutThreshold = 999
   val afterTimeoutThreshold = 1001
 
-
   // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout
   for (priorState <- Seq(None, Some(0))) {
     val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state"
@@ -748,15 +799,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
   }
 
   test("mapGroupsWithState - batch") {
-    val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
+    // Test the following
+    // - no initial state
+    // - timeouts operations work, does not throw any error [SPARK-20792]
+    // - works with primitive state type
+    val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => {
       if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+      state.setTimeoutTimestamp(0, "1 hour")
+      state.update(10)
       (key, values.size)
     }
 
     checkAnswer(
       spark.createDataset(Seq("a", "a", "b"))
         .groupByKey(x => x)
-        .mapGroupsWithState(stateFunc)
+        .mapGroupsWithState(EventTimeTimeout)(stateFunc)
         .toDF,
       spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
   }

