Repository: spark
Updated Branches:
  refs/heads/master 7786f9cc0 -> e4e46b20f


[SPARK-11681][STREAMING] Correctly update state timestamp even when state is 
not updated

Bug: The timestamp is not updated if there is data but the corresponding
state is not updated. This is wrong, as timeout is defined as "no data for a
while", not "no state update for a while".

Fix: Update the timestamp whenever a timeout is specified; otherwise there is
no need.
Also refactored the code for better testability and added unit tests.
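
To illustrate the fix, here is the core of the corrected per-key update loop
(a condensed sketch of the patch below; the setup of newStateMap,
wrappedState, and emittedRecords is elided). When a timeout is configured, a
key's timestamp is now refreshed whenever data arrives for it, even if the
tracking function left the state untouched:

    dataIterator.foreach { case (key, value) =>
      wrappedState.wrap(newStateMap.get(key))
      val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
      if (wrappedState.isRemoved) {
        newStateMap.remove(key)
      } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
        // Previously this branch ran only when wrappedState.isUpdated, so a key
        // that received data without a state update kept its old timestamp and
        // could be timed out despite fresh data. Refreshing the timestamp
        // whenever a timeout is set restores the "no data for a while" semantics.
        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
      }
      emittedRecords ++= emittedRecord
    }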

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #9648 from tdas/SPARK-11681.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e4e46b20
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e4e46b20
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e4e46b20

Branch: refs/heads/master
Commit: e4e46b20f6475f8e148d5326f7c88c57850d46a1
Parents: 7786f9c
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Thu Nov 12 19:02:49 2015 -0800
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Thu Nov 12 19:02:49 2015 -0800

----------------------------------------------------------------------
 .../spark/streaming/rdd/TrackStateRDD.scala     | 105 ++++++++------
 .../streaming/rdd/TrackStateRDDSuite.scala      | 136 ++++++++++++++++++-
 2 files changed, 192 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e4e46b20/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
index fc51496..7050378 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -32,8 +32,51 @@ import org.apache.spark._
  * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
  * sequence of records returned by the tracking function of `trackStateByKey`.
  */
-private[streaming] case class TrackStateRDDRecord[K, S, T](
-    var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
+private[streaming] case class TrackStateRDDRecord[K, S, E](
+    var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
+
+private[streaming] object TrackStateRDDRecord {
+  def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    prevRecord: Option[TrackStateRDDRecord[K, S, E]],
+    dataIterator: Iterator[(K, V)],
+    updateFunction: (Time, K, Option[V], State[S]) => Option[E],
+    batchTime: Time,
+    timeoutThresholdTime: Option[Long],
+    removeTimedoutData: Boolean
+  ): TrackStateRDDRecord[K, S, E] = {
+    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
+    val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
+
+    val emittedRecords = new ArrayBuffer[E]
+    val wrappedState = new StateImpl[S]()
+
+    // Call the tracking function on each record in the data iterator, and accordingly
+    // update the states touched, and collect the data returned by the tracking function
+    dataIterator.foreach { case (key, value) =>
+      wrappedState.wrap(newStateMap.get(key))
+      val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
+      if (wrappedState.isRemoved) {
+        newStateMap.remove(key)
+      } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
+        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
+      }
+      emittedRecords ++= emittedRecord
+    }
+
+    // Get the timed out state records, call the tracking function on each and collect the
+    // data returned
+    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
+      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
+        wrappedState.wrapTiminoutState(state)
+        val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
+        emittedRecords ++= emittedRecord
+        newStateMap.remove(key)
+      }
+    }
+
+    TrackStateRDDRecord(newStateMap, emittedRecords)
+  }
+}
 
 /**
  * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
@@ -72,16 +115,16 @@ private[streaming] class TrackStateRDDPartition(
  * @param batchTime        The time of the batch to which this RDD belongs. Used to update
  * @param timeoutThresholdTime The threshold time to identify which keys have timed out
  */
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
+private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
     private var partitionedDataRDD: RDD[(K, V)],
-    trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
+    trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
     batchTime: Time,
     timeoutThresholdTime: Option[Long]
-  ) extends RDD[TrackStateRDDRecord[K, S, T]](
+  ) extends RDD[TrackStateRDDRecord[K, S, E]](
     partitionedDataRDD.sparkContext,
     List(
-      new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
+      new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
       new OneToOneDependency(partitionedDataRDD))
   ) {
 
@@ -98,7 +141,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
   }
 
   override def compute(
-      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
+      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
 
     val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
     val prevStateRDDIterator = prevStateRDD.iterator(
@@ -106,42 +149,16 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
     val dataIterator = partitionedDataRDD.iterator(
       stateRDDPartition.partitionedDataRDDPartition, context)
 
-    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
-    val newStateMap = if (prevStateRDDIterator.hasNext) {
-      prevStateRDDIterator.next().stateMap.copy()
-    } else {
-      new EmptyStateMap[K, S]()
-    }
-
-    val emittedRecords = new ArrayBuffer[T]
-    val wrappedState = new StateImpl[S]()
-
-    // Call the tracking function on each record in the data RDD partition, and accordingly
-    // update the states touched, and the data returned by the tracking function.
-    dataIterator.foreach { case (key, value) =>
-      wrappedState.wrap(newStateMap.get(key))
-      val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
-      if (wrappedState.isRemoved) {
-        newStateMap.remove(key)
-      } else if (wrappedState.isUpdated) {
-        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
-      }
-      emittedRecords ++= emittedRecord
-    }
-
-    // If the RDD is expected to be doing a full scan of all the data in the StateMap,
-    // then use this opportunity to filter out those keys that have timed out.
-    // For each of them call the tracking function.
-    if (doFullScan && timeoutThresholdTime.isDefined) {
-      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
-        wrappedState.wrapTiminoutState(state)
-        val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
-        emittedRecords ++= emittedRecord
-        newStateMap.remove(key)
-      }
-    }
-
-    Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
+    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
+    val newRecord = TrackStateRDDRecord.updateRecordWithData(
+      prevRecord,
+      dataIterator,
+      trackingFunction,
+      batchTime,
+      timeoutThresholdTime,
+      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
+    )
+    Iterator(newRecord)
   }
 
   override protected def getPartitions: Array[Partition] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/e4e46b20/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
index f396b76..19ef5a1 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
 import org.apache.spark.streaming.{Time, State}
 import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
 
@@ -52,6 +53,131 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     assert(rdd.partitioner === Some(partitioner))
   }
 
+  test("updating state and generating emitted data in TrackStateRecord") {
+
+    val initialTime = 1000L
+    val updatedTime = 2000L
+    val thresholdTime = 1500L
+    @volatile var functionCalled = false
+
+    /**
+     * Assert that applying given data on a prior record generates correct updated record, with
+     * correct state map and emitted data
+     */
+    def assertRecordUpdate(
+        initStates: Iterable[Int],
+        data: Iterable[String],
+        expectedStates: Iterable[(Int, Long)],
+        timeoutThreshold: Option[Long] = None,
+        removeTimedoutData: Boolean = false,
+        expectedOutput: Iterable[Int] = None,
+        expectedTimingOutStates: Iterable[Int] = None,
+        expectedRemovedStates: Iterable[Int] = None
+      ): Unit = {
+      val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
+      initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
+      functionCalled = false
+      val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
+      val dataIterator = data.map { v => ("key", v) }.iterator
+      val removedStates = new ArrayBuffer[Int]
+      val timingOutStates = new ArrayBuffer[Int]
+      /**
+       * Tracking function that updates/removes state based on instructions in the data, and
+       * returns the state (when instructed or when the state is timing out).
+       */
+      def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
+        functionCalled = true
+
+        assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
+
+        data match {
+          case Some("noop") =>
+            None
+          case Some("get-state") =>
+            Some(state.getOption().getOrElse(-1))
+          case Some("update-state") =>
+            if (state.exists) state.update(state.get + 1) else state.update(0)
+            None
+          case Some("remove-state") =>
+            removedStates += state.get()
+            state.remove()
+            None
+          case None =>
+            assert(state.isTimingOut() === true, "State is not timing out when data = None")
+            timingOutStates += state.get()
+            None
+          case _ =>
+            fail("Unexpected test data")
+        }
+      }
+
+      val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
+        Some(record), dataIterator, testFunc,
+        Time(updatedTime), timeoutThreshold, removeTimedoutData)
+
+      val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
+      assert(updatedStateData.toSet === expectedStates.toSet,
+        "states do not match after updating the TrackStateRecord")
+
+      assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
+        "emitted data do not match after updating the TrackStateRecord")
+
+      assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
+        "match those that were expected to do so while updating the TrackStateRecord")
+
+      assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
+        "match those that were expected to do so while updating the TrackStateRecord")
+
+    }
+
+    // No data, no state should be changed, function should not be called.
+    assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
+    assert(functionCalled === false)
+    assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === false)
+
+    // Data present, function should be called irrespective of whether state exists
+    assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
+      expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === true)
+    assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
+    assert(functionCalled === true)
+
+    // Function called with right state data
+    assertRecordUpdate(initStates = None, data = Seq("get-state"),
+      expectedStates = None, expectedOutput = Seq(-1))
+    assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
+      expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
+
+    // Update state and timestamp when timeout is not present
+    assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
+      expectedStates = Seq((0, updatedTime)))
+    assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
+      expectedStates = Seq((1, updatedTime)))
+
+    // Remove state
+    assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
+      expectedStates = Nil, expectedRemovedStates = Seq(345))
+
+    // State strictly older than timeout threshold should be timed out
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime), removeTimedoutData = true,
+      expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
+
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Seq(123))
+
+    // State should not be timed out after it has received data
+    assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
+    assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
+
+  }
+
   test("states generated by TrackStateRDD") {
     val initStates = Seq(("k1", 0), ("k2", 0))
     val initTime = 123
@@ -148,9 +274,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     val rdd7 = testStateUpdates(                      // should remove k2's state
       rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
 
-    val rdd8 = testStateUpdates(
-      rdd7, Seq(("k3", 2)), Set()                     //
-    )
+    val rdd8 = testStateUpdates(                      // should remove k3's state
+      rdd7, Seq(("k3", 2)), Set())
   }
 
   /** Assert whether the `trackStateByKey` operation generates expected results */
@@ -176,7 +301,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
 
     // Persist to make sure that it gets computed only once and we can track precisely how many
     // state keys the computing touched
-    newStateRDD.persist()
+    newStateRDD.persist().count()
     assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
     newStateRDD
   }
@@ -188,7 +313,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
       expectedEmittedRecords: Set[T]): Unit = {
     val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
     val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
-    assert(states === expectedStates, "states after track state operation were not as expected")
+    assert(states === expectedStates,
+      "states after track state operation were not as expected")
     assert(emittedRecords === expectedEmittedRecords,
       "emitted records after track state operation were not as expected")
   }

