[SPARK-11290][STREAMING] Basic implementation of trackStateByKey

The current updateStateByKey provides stateful processing in Spark Streaming. It
allows the user to maintain per-key state and manage that state using an
updateFunction. The updateFunction is called for each key, and it uses the new
data and the existing state of the key to generate an updated state. However,
based on community feedback, we have learnt the following lessons.
* Need for more optimized state management that does not scan every key
* Need to make it easier to implement common use cases - (a) timeout of idle 
data, (b) returning items other than state

The high-level idea of this PR:
* Introduce a new API, trackStateByKey, that allows the user to update per-key
state and emit arbitrary records. The new API is necessary as this will have
significantly different semantics than the existing updateStateByKey API. This
API will have direct support for timeouts.
* Internally, the system will keep the state data as a map/list within the
partitions of the state RDDs. The new data RDDs will be partitioned
appropriately, and for each key-value record, the system will look up the
map/list in the state RDD partition and create a new map/list of updated state
data. The new state RDD partition will be created from the updated data and,
if necessary, the old data.
Here is the detailed design doc. Please take a look and provide feedback as 
comments.
https://docs.google.com/document/d/1NoALLyd83zGs1hNGMm0Pc5YOVgiPpMHugGMk6COqxxE/edit#heading=h.ph3w0clkd4em
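
The partition-level idea described above can be sketched without Spark. The
following is only an illustrative model, not the code in this PR: the names
TrackStateSketch, SimpleState, updatePartition and countWords are hypothetical,
a plain immutable Map stands in for the per-partition state map, and timed-out
states are simply dropped (the real API calls the tracking function one final
time for them, which this sketch omits). It assumes a state is idle when its
last update time is older than the timeout threshold.

```scala
// Spark-free model of updating one state partition; all names here are hypothetical.
object TrackStateSketch {

  /** Minimal stand-in for the State handle given to the tracking function. */
  final class SimpleState[S](initial: Option[S]) {
    private var value: Option[S] = initial
    private var removed = false

    def getOption: Option[S] = value

    def update(newState: S): Unit = {
      require(!removed, "Cannot update the state after it has been removed")
      value = Some(newState)
    }

    def remove(): Unit = {
      value = None
      removed = true
    }
  }

  /**
   * Builds the next state map of a partition from the previous one and the new data.
   * Each state is stored with the batch time (ms) at which its key last received data.
   */
  def updatePartition[K, V, S, E](
      prev: Map[K, (S, Long)],
      data: Seq[(K, V)],
      batchTimeMs: Long,
      timeoutThresholdMs: Option[Long])(
      track: (K, Option[V], SimpleState[S]) => Option[E]): (Map[K, (S, Long)], Seq[E]) = {
    var next = prev
    val emitted = Vector.newBuilder[E]

    // Only keys that appear in the new data are looked up; other keys are not visited here.
    data.foreach { case (key, value) =>
      val state = new SimpleState[S](next.get(key).map(_._1))
      track(key, Some(value), state).foreach(emitted += _)
      state.getOption match {
        case Some(s) => next = next.updated(key, (s, batchTimeMs)) // created or updated
        case None => next = next - key                             // removed or never set
      }
    }

    // Assumed timeout rule: drop states whose key last received data before the threshold.
    val kept = timeoutThresholdMs match {
      case Some(threshold) =>
        next.filter { case (_, (_, lastUpdatedMs)) => lastUpdatedMs >= threshold }
      case None => next
    }
    (kept, emitted.result())
  }

  // Word-count tracking function in the spirit of the StatefulNetworkWordCount example.
  val countWords: (String, Option[Int], SimpleState[Int]) => Option[(String, Int)] =
    (word, one, state) => {
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)
      Some((word, sum))
    }
}
```

Calling TrackStateSketch.updatePartition(prevStates, newData, batchTimeMs,
thresholdMs)(TrackStateSketch.countWords) returns the next state map together
with the emitted (word, count) records, which mirrors how the new API
separates the maintained state from the records it emits.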

This is still WIP. Major things left to be done.
- [x] Implement basic functionality of state tracking, with initial RDD and 
timeouts
- [x] Unit tests for state tracking
- [x] Unit tests for initial RDD and timeout
- [ ] Unit tests for TrackStateRDD
       - [x] state creating, updating, removing
       - [ ] emitting
       - [ ] checkpointing
- [x] Misc unit tests for State, TrackStateSpec, etc.
- [x] Update docs and experimental tags

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #9256 from tdas/trackStateByKey.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/99f5f988
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/99f5f988
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/99f5f988

Branch: refs/heads/master
Commit: 99f5f988612b3093d73d9ce98819767e822fcbff
Parents: bd70244
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Tue Nov 10 23:16:18 2015 -0800
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Nov 10 23:16:18 2015 -0800

----------------------------------------------------------------------
 .../streaming/StatefulNetworkWordCount.scala    |  25 +-
 .../org/apache/spark/streaming/State.scala      | 193 ++++++++
 .../org/apache/spark/streaming/StateSpec.scala  | 212 ++++++++
 .../dstream/PairDStreamFunctions.scala          |  46 +-
 .../streaming/dstream/TrackStateDStream.scala   | 142 ++++++
 .../spark/streaming/rdd/TrackStateRDD.scala     | 188 +++++++
 .../apache/spark/streaming/util/StateMap.scala  | 337 +++++++++++++
 .../apache/spark/streaming/StateMapSuite.scala  | 314 ++++++++++++
 .../spark/streaming/TrackStateByKeySuite.scala  | 494 +++++++++++++++++++
 .../streaming/rdd/TrackStateRDDSuite.scala      | 193 ++++++++
 10 files changed, 2125 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/99f5f988/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index 02ba1c2..be2ae0b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -44,18 +44,6 @@ object StatefulNetworkWordCount {
 
     StreamingExamples.setStreamingLogLevels()
 
-    val updateFunc = (values: Seq[Int], state: Option[Int]) => {
-      val currentCount = values.sum
-
-      val previousCount = state.getOrElse(0)
-
-      Some(currentCount + previousCount)
-    }
-
-    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
-      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
-    }
-
     val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
     // Create the context with a 1 second batch size
     val ssc = new StreamingContext(sparkConf, Seconds(1))
@@ -71,9 +59,16 @@ object StatefulNetworkWordCount {
     val wordDstream = words.map(x => (x, 1))
 
     // Update the cumulative count using updateStateByKey
-    // This will give a Dstream made of state (which is the cumulative count of the words)
-    val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
-      new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD)
+    // This will give a DStream made of state (which is the cumulative count of the words)
+    val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => {
+      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (word, sum)
+      state.update(sum)
+      Some(output)
+    }
+
+    val stateDstream = wordDstream.trackStateByKey(
+      StateSpec.function(trackStateFunc).initialState(initialRDD))
     stateDstream.print()
     ssc.start()
     ssc.awaitTermination()

http://git-wip-us.apache.org/repos/asf/spark/blob/99f5f988/streaming/src/main/scala/org/apache/spark/streaming/State.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
new file mode 100644
index 0000000..7dd1b72
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.language.implicitConversions
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Abstract class for getting and updating the tracked state in the 
`trackStateByKey` operation of
+ * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] 
(Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] 
(Java).
+ *
+ * Scala example of using `State`:
+ * {{{
+ *    // A tracking function that maintains an integer state and returns a String
+ *    def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ *      // Check if state exists
+ *      if (state.exists) {
+ *        val existingState = state.get  // Get the existing state
+ *        val shouldRemove = ...         // Decide whether to remove the state
+ *        if (shouldRemove) {
+ *          state.remove()     // Remove the state
+ *        } else {
+ *          val newState = ...
+ *          state.update(newState)    // Set the new state
+ *        }
+ *      } else {
+ *        val initialState = ...
+ *        state.update(initialState)  // Set the initial state
+ *      }
+ *      ... // return something
+ *    }
+ *
+ * }}}
+ *
+ * Java example:
+ * {{{
+ *      TODO(@zsxwing)
+ * }}}
+ */
+@Experimental
+sealed abstract class State[S] {
+
+  /** Whether the state already exists */
+  def exists(): Boolean
+
+  /**
+   * Get the state if it exists, otherwise it will throw 
`java.util.NoSuchElementException`.
+   * Check with `exists()` whether the state exists or not before calling 
`get()`.
+   *
+   * @throws java.util.NoSuchElementException If the state does not exist.
+   */
+  def get(): S
+
+  /**
+   * Update the state with a new value.
+   *
+   * State cannot be updated if it has been already removed (that is, 
`remove()` has already been
+   * called) or it is going to be removed due to timeout (that is, 
`isTimingOut()` is `true`).
+   *
+   * @throws java.lang.IllegalArgumentException If the state has already been 
removed, or is
+   *                                            going to be removed
+   */
+  def update(newState: S): Unit
+
+  /**
+   * Remove the state if it exists.
+   *
+   * State cannot be removed if it has been already removed (that is, `remove()` has already been
+   * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`).
+   */
+  def remove(): Unit
+
+  /**
+   * Whether the state is timing out and going to be removed by the system after the current batch.
+   * This timeout can occur if timeout duration has been specified in the
+   * [[org.apache.spark.streaming.StateSpec StateSpec]] and the key has not received any new data
+   * for that timeout duration.
+   */
+  def isTimingOut(): Boolean
+
+  /**
+   * Get the state as an [[scala.Option]]. It will be `Some(state)` if it 
exists, otherwise `None`.
+   */
+  @inline final def getOption(): Option[S] = if (exists) Some(get()) else None
+
+  @inline final override def toString(): String = {
+    getOption.map { _.toString }.getOrElse("<state not set>")
+  }
+}
+
+/** Internal implementation of the [[State]] interface */
+private[streaming] class StateImpl[S] extends State[S] {
+
+  private var state: S = null.asInstanceOf[S]
+  private var defined: Boolean = false
+  private var timingOut: Boolean = false
+  private var updated: Boolean = false
+  private var removed: Boolean = false
+
+  // ========= Public API =========
+  override def exists(): Boolean = {
+    defined
+  }
+
+  override def get(): S = {
+    if (defined) {
+      state
+    } else {
+      throw new NoSuchElementException("State is not set")
+    }
+  }
+
+  override def update(newState: S): Unit = {
+    require(!removed, "Cannot update the state after it has been removed")
+    require(!timingOut, "Cannot update the state that is timing out")
+    state = newState
+    defined = true
+    updated = true
+  }
+
+  override def isTimingOut(): Boolean = {
+    timingOut
+  }
+
+  override def remove(): Unit = {
+    require(!timingOut, "Cannot remove the state that is timing out")
+    require(!removed, "Cannot remove the state that has already been removed")
+    defined = false
+    updated = false
+    removed = true
+  }
+
+  // ========= Internal API =========
+
+  /** Whether the state has been marked for removing */
+  def isRemoved(): Boolean = {
+    removed
+  }
+
+  /** Whether the state has been updated */
+  def isUpdated(): Boolean = {
+    updated
+  }
+
+  /**
+   * Update the internal data and flags in `this` to the given state option.
+   * This method allows `this` object to be reused across many state records.
+   */
+  def wrap(optionalState: Option[S]): Unit = {
+    optionalState match {
+      case Some(newState) =>
+        this.state = newState
+        defined = true
+
+      case None =>
+        this.state = null.asInstanceOf[S]
+        defined = false
+    }
+    timingOut = false
+    removed = false
+    updated = false
+  }
+
+  /**
+   * Update the internal data and flags in `this` to the given state that is 
going to be timed out.
+   * This method allows `this` object to be reused across many state records.
+   */
+  def wrapTiminoutState(newState: S): Unit = {
+    this.state = newState
+    defined = true
+    timingOut = true
+    removed = false
+    updated = false
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/99f5f988/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
new file mode 100644
index 0000000..c9fe35e
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.ClosureCleaner
+import org.apache.spark.{HashPartitioner, Partitioner}
+
+
+/**
+ * :: Experimental ::
+ * Abstract class representing all the specifications of the DStream 
transformation
+ * `trackStateByKey` operation of a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] 
(Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] 
(Java).
+ * Use [[org.apache.spark.streaming.StateSpec StateSpec.function()]] to create instances of
+ * this class.
+ *
+ * Example in Scala:
+ * {{{
+ *    def trackingFunction(data: Option[ValueType], wrappedState: 
State[StateType]): EmittedType = {
+ *      ...
+ *    }
+ *
+ *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *
+ *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, 
EmittedDataType](spec)
+ * }}}
+ *
+ * Example in Java:
+ * {{{
+ *    StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
+ *      StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
+ *                    .numPartitions(10);
+ *
+ *    JavaDStream[EmittedDataType] emittedRecordDStream =
+ *      javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ * }}}
+ */
+@Experimental
+sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] 
extends Serializable {
+
+  /** Set the RDD containing the initial states that will be used by 
`trackStateByKey` */
+  def initialState(rdd: RDD[(KeyType, StateType)]): this.type
+
+  /** Set the RDD containing the initial states that will be used by 
`trackStateByKey` */
+  def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
+
+  /**
+   * Set the number of partitions by which the state RDDs generated by 
`trackStateByKey`
+   * will be partitioned. Hash partitioning will be used.
+   */
+  def numPartitions(numPartitions: Int): this.type
+
+  /**
+   * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
+   * partitioned.
+   */
+  def partitioner(partitioner: Partitioner): this.type
+
+  /**
+   * Set the duration after which the state of an idle key will be removed. A key and its state are
+   * considered idle if the key has not received any data for at least the given duration. The state
+   * tracking function will be called one final time on the idle states that are going to be
+   * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] is set
+   * to `true` in that call.
+   */
+  def timeout(idleDuration: Duration): this.type
+}
+
+
+/**
+ * :: Experimental ::
+ * Builder object for creating instances of 
[[org.apache.spark.streaming.StateSpec StateSpec]]
+ * that is used for specifying the parameters of the DStream transformation
+ * `trackStateByKey` operation of a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] 
(Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] 
(Java).
+ *
+ * Example in Scala:
+ * {{{
+ *    def trackingFunction(data: Option[ValueType], wrappedState: 
State[StateType]): EmittedType = {
+ *      ...
+ *    }
+ *
+ *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *
+ *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, 
EmittedDataType](spec)
+ * }}}
+ *
+ * Example in Java:
+ * {{{
+ *    StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
+ *      StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
+ *                    .numPartitions(10);
+ *
+ *    JavaDStream[EmittedDataType] emittedRecordDStream =
+ *      javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ * }}}
+ */
+@Experimental
+object StateSpec {
+  /**
+   * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] 
(Scala) or a
+   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] 
(Java).
+   * @param trackingFunction The function applied on every data item to manage 
the associated state
+   *                         and generate the emitted data
+   * @tparam KeyType      Class of the keys
+   * @tparam ValueType    Class of the values
+   * @tparam StateType    Class of the states data
+   * @tparam EmittedType  Class of the emitted data
+   */
+  def function[KeyType, ValueType, StateType, EmittedType](
+      trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) 
=> Option[EmittedType]
+    ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+    ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+    new StateSpecImpl(trackingFunction)
+  }
+
+  /**
+   * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] 
(Scala) or a
+   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] 
(Java).
+   * @param trackingFunction The function applied on every data item to manage 
the associated state
+   *                         and generate the emitted data
+   * @tparam ValueType    Class of the values
+   * @tparam StateType    Class of the states data
+   * @tparam EmittedType  Class of the emitted data
+   */
+  def function[KeyType, ValueType, StateType, EmittedType](
+      trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
+    ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+    ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+    val wrappedFunction =
+      (time: Time, key: Any, value: Option[ValueType], state: 
State[StateType]) => {
+        Some(trackingFunction(value, state))
+      }
+    new StateSpecImpl(wrappedFunction)
+  }
+}
+
+
+/** Internal implementation of [[org.apache.spark.streaming.StateSpec]] 
interface. */
+private[streaming]
+case class StateSpecImpl[K, V, S, T](
+    function: (Time, K, Option[V], State[S]) => Option[T]) extends 
StateSpec[K, V, S, T] {
+
+  require(function != null)
+
+  @volatile private var partitioner: Partitioner = null
+  @volatile private var initialStateRDD: RDD[(K, S)] = null
+  @volatile private var timeoutInterval: Duration = null
+
+  override def initialState(rdd: RDD[(K, S)]): this.type = {
+    this.initialStateRDD = rdd
+    this
+  }
+
+  override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = {
+    this.initialStateRDD = javaPairRDD.rdd
+    this
+  }
+
+
+  override def numPartitions(numPartitions: Int): this.type = {
+    this.partitioner(new HashPartitioner(numPartitions))
+    this
+  }
+
+  override def partitioner(partitioner: Partitioner): this.type = {
+    this.partitioner = partitioner
+    this
+  }
+
+  override def timeout(interval: Duration): this.type = {
+    this.timeoutInterval = interval
+    this
+  }
+
+  // ================= Private Methods =================
+
+  private[streaming] def getFunction(): (Time, K, Option[V], State[S]) => 
Option[T] = function
+
+  private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = 
Option(initialStateRDD)
+
+  private[streaming] def getPartitioner(): Option[Partitioner] = 
Option(partitioner)
+
+  private[streaming] def getTimeoutInterval(): Option[Duration] = 
Option(timeoutInterval)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/99f5f988/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 71bec96..fb691ee 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -24,19 +24,19 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.{JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
 
-import org.apache.spark.{HashPartitioner, Partitioner}
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{Duration, Time}
 import org.apache.spark.streaming.StreamingContext.rddToFileName
+import org.apache.spark.streaming._
 import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf}
+import org.apache.spark.{HashPartitioner, Partitioner}
 
 /**
  * Extra functions available on DStream of (key, value) pairs through an 
implicit conversion.
  */
 class PairDStreamFunctions[K, V](self: DStream[(K, V)])
     (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K])
-  extends Serializable
-{
+  extends Serializable {
   private[streaming] def ssc = self.ssc
 
   private[streaming] def sparkContext = self.context.sparkContext
@@ -351,6 +351,44 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
   }
 
   /**
+   * :: Experimental ::
+   * Return a new DStream of data generated by combining the key-value data in 
`this` stream
+   * with a continuously updated per-key state. The user-provided state 
tracking function is
+   * applied on each keyed data item along with its corresponding state. The 
function can choose to
+   * update/remove the state and return a transformed data, which forms the
+   * [[org.apache.spark.streaming.dstream.TrackStateDStream]].
+   *
+   * The specification of this transformation is made through the
+   * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
+   * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
+   * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details.
+   *
+   * Example of using `trackStateByKey`:
+   * {{{
+   *    def trackingFunction(data: Option[Int], wrappedState: State[Int]): 
String = {
+   *      // Check if state exists, accordingly update/remove state and return 
transformed data
+   *    }
+   *
+   *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+   *
+   *    val trackStateDStream = keyValueDStream.trackStateByKey[Int, 
String](spec)
+   * }}}
+   *
+   * @param spec          Specification of this transformation
+   * @tparam StateType    Class type of the state
+   * @tparam EmittedType  Class type of the transformed data returned by the tracking function
+   */
+  @Experimental
+  def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
+      spec: StateSpec[K, V, StateType, EmittedType]
+    ): TrackStateDStream[K, V, StateType, EmittedType] = {
+    new TrackStateDStreamImpl[K, V, StateType, EmittedType](
+      self,
+      spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]
+    )
+  }
+
+  /**
    * Return a new "state" DStream where the state for each key is updated by 
applying
    * the given function on the previous state of the key and the new values of 
each key.
    * Hash partitioning is used to generate the RDDs with Spark's default 
number of partitions.

http://git-wip-us.apache.org/repos/asf/spark/blob/99f5f988/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
new file mode 100644
index 0000000..58d89c9
--- /dev/null
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.dstream
+
+import scala.reflect.ClassTag
+
+import org.apache.spark._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
+
+/**
+ * :: Experimental ::
+ * DStream representing the stream of records emitted by the tracking function 
in the
+ * `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+ * Additionally, it also gives access to the stream of state snapshots, that 
is, the state data of
+ * all keys after a batch has updated them.
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam StateType Class of the state data
+ * @tparam EmittedType Class of the emitted records
+ */
+@Experimental
+sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, 
EmittedType: ClassTag](
+    ssc: StreamingContext) extends DStream[EmittedType](ssc) {
+
+  /** Return a pair DStream where each RDD is the snapshot of the state of all 
the keys. */
+  def stateSnapshots(): DStream[(KeyType, StateType)]
+}
+
+/** Internal implementation of the [[TrackStateDStream]] */
+private[streaming] class TrackStateDStreamImpl[
+    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: 
ClassTag](
+    dataStream: DStream[(KeyType, ValueType)],
+    spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType])
+  extends TrackStateDStream[KeyType, ValueType, StateType, 
EmittedType](dataStream.context) {
+
+  private val internalStream =
+    new InternalTrackStateDStream[KeyType, ValueType, StateType, 
EmittedType](dataStream, spec)
+
+  override def slideDuration: Duration = internalStream.slideDuration
+
+  override def dependencies: List[DStream[_]] = List(internalStream)
+
+  override def compute(validTime: Time): Option[RDD[EmittedType]] = {
+    internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { 
_.emittedRecords } }
+  }
+
+  /**
+   * Forward the checkpoint interval to the internal DStream that computes the state maps. This
+   * is to make sure that this DStream does not get checkpointed, only the internal stream.
+   */
+  override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] 
= {
+    internalStream.checkpoint(checkpointInterval)
+    this
+  }
+
+  /** Return a pair DStream where each RDD is the snapshot of the state of all 
the keys. */
+  def stateSnapshots(): DStream[(KeyType, StateType)] = {
+    internalStream.flatMap {
+      _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable }
+  }
+
+  def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass
+
+  def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass
+
+  def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
+
+  def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass
+}
+
+/**
+ * A DStream that allows per-key state to be maintained, and arbitrary records to be generated
+ * based on updates to the state. This is the main DStream that implements the `trackStateByKey`
+ * operation on DStreams.
+ *
+ * @param parent Parent (key, value) stream that is the source
+ * @param spec Specifications of the trackStateByKey operation
+ * @tparam K   Key type
+ * @tparam V   Value type
+ * @tparam S   Type of the state maintained
+ * @tparam E   Type of the emitted data
+ */
+private[streaming]
+class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: 
ClassTag](
+    parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
+  extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) {
+
+  persist(StorageLevel.MEMORY_ONLY)
+
+  private val partitioner = spec.getPartitioner().getOrElse(
+    new HashPartitioner(ssc.sc.defaultParallelism))
+
+  private val trackingFunction = spec.getFunction()
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def dependencies: List[DStream[_]] = List(parent)
+
+  /** Enable automatic checkpointing */
+  override val mustCheckpoint = true
+
+  /** Method that generates a RDD for the given time */
+  override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, 
E]]] = {
+    // Get the previous state or create a new empty state RDD
+    val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
+      TrackStateRDD.createFromPairRDD[K, V, S, E](
+        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, 
S)](ssc.sparkContext)),
+        partitioner, validTime
+      )
+    }
+
+    // Compute the new state RDD with previous state RDD and partitioned data 
RDD
+    parent.getOrCompute(validTime).map { dataRDD =>
+      val partitionedDataRDD = dataRDD.partitionBy(partitioner)
+      val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
+        (validTime - interval).milliseconds
+      }
+      new TrackStateRDD(
+        prevStateRDD, partitionedDataRDD, trackingFunction, validTime, 
timeoutThresholdTime)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/99f5f988/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
new file mode 100644
index 0000000..ed7cea2
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.rdd
+
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
+import org.apache.spark.streaming.{Time, StateImpl, State}
+import org.apache.spark.streaming.util.{EmptyStateMap, StateMap}
+import org.apache.spark.util.Utils
+import org.apache.spark._
+
+/**
+ * Record storing the keyed state of a [[TrackStateRDD]]. Each record contains a [[StateMap]] and
+ * a sequence of records returned by the tracking function of `trackStateByKey`.
+ */
+private[streaming] case class TrackStateRDDRecord[K, S, T](
+    var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
+
+/**
+ * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
+ * RDD, and a partitioned keyed-data RDD
+ */
+private[streaming] class TrackStateRDDPartition(
+    idx: Int,
+    @transient private var prevStateRDD: RDD[_],
+    @transient private var partitionedDataRDD: RDD[_]) extends Partition {
+
+  private[rdd] var previousSessionRDDPartition: Partition = null
+  private[rdd] var partitionedDataRDDPartition: Partition = null
+
+  override def index: Int = idx
+  override def hashCode(): Int = idx
+
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
+    // Update the reference to parent split at the time of task serialization
+    previousSessionRDDPartition = prevStateRDD.partitions(index)
+    partitionedDataRDDPartition = partitionedDataRDD.partitions(index)
+    oos.defaultWriteObject()
+  }
+}
+
+
+/**
+ * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records.
+ * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a
+ * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking
+ * function of `trackStateByKey`.
+ * @param prevStateRDD The previous TrackStateRDD whose StateMap data is used to create `this` RDD
+ * @param partitionedDataRDD The partitioned data RDD which is used to update the previous
+ *                           StateMaps in the `prevStateRDD` to create `this` RDD
+ * @param trackingFunction The function that will be used to update state and return new data
+ * @param batchTime        The time of the batch to which this RDD belongs. Used as the update
+ *                         time of the states updated in this batch
+ */
+private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
+    private var partitionedDataRDD: RDD[(K, V)],
+    trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
+    batchTime: Time, timeoutThresholdTime: Option[Long]
+  ) extends RDD[TrackStateRDDRecord[K, S, T]](
+    partitionedDataRDD.sparkContext,
+    List(
+      new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
+      new OneToOneDependency(partitionedDataRDD))
+  ) {
+
+  @volatile private var doFullScan = false
+
+  require(prevStateRDD.partitioner.nonEmpty)
+  require(partitionedDataRDD.partitioner == prevStateRDD.partitioner)
+
+  override val partitioner = prevStateRDD.partitioner
+
+  override def checkpoint(): Unit = {
+    super.checkpoint()
+    doFullScan = true
+  }
+
+  override def compute(
+      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
+
+    val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
+    val prevStateRDDIterator = prevStateRDD.iterator(
+      stateRDDPartition.previousSessionRDDPartition, context)
+    val dataIterator = partitionedDataRDD.iterator(
+      stateRDDPartition.partitionedDataRDDPartition, context)
+
+    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
+    val newStateMap = if (prevStateRDDIterator.hasNext) {
+      prevStateRDDIterator.next().stateMap.copy()
+    } else {
+      new EmptyStateMap[K, S]()
+    }
+
+    val emittedRecords = new ArrayBuffer[T]
+    val wrappedState = new StateImpl[S]()
+
+    // Call the tracking function on each record in the data RDD partition, and accordingly
+    // update the states touched, and the data returned by the tracking function.
+    dataIterator.foreach { case (key, value) =>
+      wrappedState.wrap(newStateMap.get(key))
+      val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
+      if (wrappedState.isRemoved) {
+        newStateMap.remove(key)
+      } else if (wrappedState.isUpdated) {
+        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
+      }
+      emittedRecords ++= emittedRecord
+    }
+
+    // If the RDD is expected to be doing a full scan of all the data in the StateMap,
+    // then use this opportunity to filter out those keys that have timed out.
+    // For each of them call the tracking function.
+    if (doFullScan && timeoutThresholdTime.isDefined) {
+      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
+        wrappedState.wrapTiminoutState(state)
+        val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
+        emittedRecords ++= emittedRecord
+        newStateMap.remove(key)
+      }
+    }
+
+    Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
+  }
+
+  override protected def getPartitions: Array[Partition] = {
+    Array.tabulate(prevStateRDD.partitions.length) { i =>
+      new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
+  }
+
+  override def clearDependencies(): Unit = {
+    super.clearDependencies()
+    prevStateRDD = null
+    partitionedDataRDD = null
+  }
+
+  def setFullScan(): Unit = {
+    doFullScan = true
+  }
+}
+
+private[streaming] object TrackStateRDD {
+
+  def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+      pairRDD: RDD[(K, S)],
+      partitioner: Partitioner,
+      updateTime: Time): TrackStateRDD[K, V, S, T] = {
+
+    val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
+      val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
+      iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
+      Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
+    }, preservesPartitioning = true)
+
+    val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
+
+    val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
+
+    new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+  }
+}
+
+private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+    parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) {
+  override protected def getPartitions: Array[Partition] = parent.partitions
+  override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
+    parent.compute(partition, context).flatMap { _.emittedRecords }
+  }
+}
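
The per-partition update loop in `TrackStateRDD.compute` above can be sketched without any Spark dependency. This is a simplified illustration under stated assumptions, not the patch's implementation: `TrackStateLoopSketch`, `SimpleState` and `processPartition` are hypothetical names, a plain Scala `Map` stands in for `StateMap`, and batch times and timeouts are left out.

```scala
import scala.collection.mutable

object TrackStateLoopSketch {

  /** Hypothetical, simplified stand-in for the state wrapper handed to the tracking function. */
  final class SimpleState[S] {
    private var value: Option[S] = None
    private var updated = false
    private var removed = false

    /** Point the wrapper at the existing state of the next key and reset the flags. */
    def wrap(existing: Option[S]): Unit = { value = existing; updated = false; removed = false }
    def getOption: Option[S] = value
    def update(s: S): Unit = { value = Some(s); updated = true; removed = false }
    def remove(): Unit = { value = None; removed = true; updated = false }
    def isUpdated: Boolean = updated
    def isRemoved: Boolean = removed
  }

  /** Apply the tracking function to every record; return the new state and the emitted records. */
  def processPartition[K, V, S, T](
      prevState: Map[K, S],
      data: Seq[(K, V)],
      track: (K, Option[V], SimpleState[S]) => Option[T]): (Map[K, S], Seq[T]) = {
    val newState = mutable.Map[K, S]() ++= prevState // copy, so prevState is never mutated
    val emitted = mutable.ArrayBuffer[T]()
    val wrapped = new SimpleState[S] // one wrapper object reused for all keys

    data.foreach { case (key, value) =>
      wrapped.wrap(newState.get(key))
      val out = track(key, Some(value), wrapped)
      if (wrapped.isRemoved) {
        newState.remove(key)
      } else if (wrapped.isUpdated) {
        wrapped.getOption.foreach(s => newState.put(key, s))
      }
      emitted ++= out
    }
    (newState.toMap, emitted.toSeq)
  }
}
```

As in the patch, only the keys that appear in the batch data are touched, so the cost of one batch is proportional to the amount of new data rather than to the total number of keys.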

http://git-wip-us.apache.org/repos/asf/spark/blob/99f5f988/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
new file mode 100644
index 0000000..ed622ef
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.util
+
+import java.io.{ObjectInputStream, ObjectOutputStream}
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._
+import org.apache.spark.util.collection.OpenHashMap
+
+/** Internal interface for defining the map that keeps track of sessions. */
+private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable {
+
+  /** Get the state for a key if it exists */
+  def get(key: K): Option[S]
+
+  /** Get all the keys and states whose updated time is older than the given threshold time */
+  def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)]
+
+  /** Get all the keys and states in this map. */
+  def getAll(): Iterator[(K, S, Long)]
+
+  /** Add or update state */
+  def put(key: K, state: S, updatedTime: Long): Unit
+
+  /** Remove a key */
+  def remove(key: K): Unit
+
+  /**
+   * Shallow copy `this` map to create a new state map.
+   * Updates to the new map should not mutate `this` map.
+   */
+  def copy(): StateMap[K, S]
+
+  def toDebugString(): String = toString()
+}
+
+/** Companion object for [[StateMap]], with utility methods */
+private[streaming] object StateMap {
+  def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S]
+
+  def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
+    val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
+      DELTA_CHAIN_LENGTH_THRESHOLD)
+    new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold)
+  }
+}
+
+/** Implementation of StateMap interface representing an empty map */
+private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] {
+  override def put(key: K, session: S, updateTime: Long): Unit = {
+    throw new NotImplementedError("put() should not be called on an EmptyStateMap")
+  }
+  override def get(key: K): Option[S] = None
+  override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty
+  override def getAll(): Iterator[(K, S, Long)] = Iterator.empty
+  override def copy(): StateMap[K, S] = this
+  override def remove(key: K): Unit = { }
+  override def toDebugString(): String = ""
+}
+
+/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
+private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
+    @transient @volatile var parentStateMap: StateMap[K, S],
+    initialCapacity: Int = 64,
+    deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
+  ) extends StateMap[K, S] { self =>
+
+  def this(initialCapacity: Int, deltaChainThreshold: Int) = this(
+    new EmptyStateMap[K, S],
+    initialCapacity = initialCapacity,
+    deltaChainThreshold = deltaChainThreshold)
+
+  def this(deltaChainThreshold: Int) = this(
+    initialCapacity = 64, deltaChainThreshold = deltaChainThreshold)
+
+  def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
+
+  @transient @volatile private var deltaMap =
+    new OpenHashMap[K, StateInfo[S]](initialCapacity)
+
+  /** Get the session data if it exists */
+  override def get(key: K): Option[S] = {
+    val stateInfo = deltaMap(key)
+    if (stateInfo != null) {
+      if (!stateInfo.deleted) {
+        Some(stateInfo.data)
+      } else {
+        None
+      }
+    } else {
+      parentStateMap.get(key)
+    }
+  }
+
+  /** Get all the keys and states whose updated time is older than the given threshold time */
+  override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = {
+    val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) =>
+      !deltaMap.contains(key)
+    }
+
+    val updatedStates = deltaMap.iterator.filter { case (_, stateInfo) =>
+      !stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime
+    }.map { case (key, stateInfo) =>
+      (key, stateInfo.data, stateInfo.updateTime)
+    }
+    oldStates ++ updatedStates
+  }
+
+  /** Get all the keys and states in this map. */
+  override def getAll(): Iterator[(K, S, Long)] = {
+
+    val oldStates = parentStateMap.getAll().filter { case (key, _, _) =>
+      !deltaMap.contains(key)
+    }
+
+    val updatedStates = deltaMap.iterator.filter { ! _._2.deleted }.map { case (key, stateInfo) =>
+      (key, stateInfo.data, stateInfo.updateTime)
+    }
+    oldStates ++ updatedStates
+  }
+
+  /** Add or update state */
+  override def put(key: K, state: S, updateTime: Long): Unit = {
+    val stateInfo = deltaMap(key)
+    if (stateInfo != null) {
+      stateInfo.update(state, updateTime)
+    } else {
+      deltaMap.update(key, new StateInfo(state, updateTime))
+    }
+  }
+
+  /** Remove a state */
+  override def remove(key: K): Unit = {
+    val stateInfo = deltaMap(key)
+    if (stateInfo != null) {
+      stateInfo.markDeleted()
+    } else {
+      val newInfo = new StateInfo[S](deleted = true)
+      deltaMap.update(key, newInfo)
+    }
+  }
+
+  /**
+   * Shallow copy the map to create a new session store. Updates to the new map
+   * should not mutate `this` map.
+   */
+  override def copy(): StateMap[K, S] = {
+    new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold)
+  }
+
+  /** Whether the delta chain length is long enough that it should be compacted */
+  def shouldCompact: Boolean = {
+    deltaChainLength >= deltaChainThreshold
+  }
+
+  /** Length of the delta chains of this map */
+  def deltaChainLength: Int = parentStateMap match {
+    case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1
+    case _ => 0
+  }
+
+  /**
+   * Approximate number of keys in the map. This is an overestimation that is mainly used to
+   * reserve capacity in a new map at delta compaction time.
+   */
+  def approxSize: Int = deltaMap.size + {
+    parentStateMap match {
+      case s: OpenHashMapBasedStateMap[_, _] => s.approxSize
+      case _ => 0
+    }
+  }
+
+  /** Get all the data of this map as string formatted as a tree based on the delta depth */
+  override def toDebugString(): String = {
+    val tabs = if (deltaChainLength > 0) {
+      ("    " * (deltaChainLength - 1)) + "+--- "
+    } else ""
+    parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "")
+  }
+
+  override def toString(): String = {
+    s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]"
+  }
+
+  /**
+   * Serialize the map data. Besides serialization, this method also compacts the deltas
+   * (if needed) in a single pass over all the data in the map.
+   */
+
+  private def writeObject(outputStream: ObjectOutputStream): Unit = {
+    // Write all the non-transient fields, especially class tags, etc.
+    outputStream.defaultWriteObject()
+
+    // Write the data in the delta of this state map
+    outputStream.writeInt(deltaMap.size)
+    val deltaMapIterator = deltaMap.iterator
+    var deltaMapCount = 0
+    while (deltaMapIterator.hasNext) {
+      deltaMapCount += 1
+      val (key, stateInfo) = deltaMapIterator.next()
+      outputStream.writeObject(key)
+      outputStream.writeObject(stateInfo)
+    }
+    assert(deltaMapCount == deltaMap.size)
+
+    // Write the data in the parent state map while copying the data into a new parent map for
+    // compaction (if needed)
+    val doCompaction = shouldCompact
+    val newParentSessionStore = if (doCompaction) {
+      val initCapacity = if (approxSize > 0) approxSize else 64
+      new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold)
+    } else { null }
+
+    val iterOfActiveSessions = parentStateMap.getAll()
+
+    var parentSessionCount = 0
+
+    // First write the approximate size of the data to be written, so that readObject can
+    // allocate an appropriately sized OpenHashMap.
+    outputStream.writeInt(approxSize)
+
+    while(iterOfActiveSessions.hasNext) {
+      parentSessionCount += 1
+
+      val (key, state, updateTime) = iterOfActiveSessions.next()
+      outputStream.writeObject(key)
+      outputStream.writeObject(state)
+      outputStream.writeLong(updateTime)
+
+      if (doCompaction) {
+        newParentSessionStore.deltaMap.update(
+          key, StateInfo(state, updateTime, deleted = false))
+      }
+    }
+
+    // Write the final limit marking object with the correct count of records written.
+    val limiterObj = new LimitMarker(parentSessionCount)
+    outputStream.writeObject(limiterObj)
+    if (doCompaction) {
+      parentStateMap = newParentSessionStore
+    }
+  }
+
+  /** Deserialize the map data. */
+  private def readObject(inputStream: ObjectInputStream): Unit = {
+
+    // Read the non-transient fields, especially class tags, etc.
+    inputStream.defaultReadObject()
+
+    // Read the data of the delta
+    val deltaMapSize = inputStream.readInt()
+    deltaMap = new OpenHashMap[K, StateInfo[S]]()
+    var deltaMapCount = 0
+    while (deltaMapCount < deltaMapSize) {
+      val key = inputStream.readObject().asInstanceOf[K]
+      val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]]
+      deltaMap.update(key, sessionInfo)
+      deltaMapCount += 1
+    }
+
+
+    // Read the data of the parent map. Keep reading records until the limiter is reached.
+    // First read the approximate number of records to expect and allocate a properly sized
+    // OpenHashMap
+    val parentSessionStoreSizeHint = inputStream.readInt()
+    val newParentSessionStore = new OpenHashMapBasedStateMap[K, S](
+      initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold)
+
+    // Read the records until the limit marking object has been reached
+    var parentSessionLoopDone = false
+    while(!parentSessionLoopDone) {
+      val obj = inputStream.readObject()
+      if (obj.isInstanceOf[LimitMarker]) {
+        parentSessionLoopDone = true
+        val expectedCount = obj.asInstanceOf[LimitMarker].num
+        assert(expectedCount == newParentSessionStore.deltaMap.size)
+      } else {
+        val key = obj.asInstanceOf[K]
+        val state = inputStream.readObject().asInstanceOf[S]
+        val updateTime = inputStream.readLong()
+        newParentSessionStore.deltaMap.update(
+          key, StateInfo(state, updateTime, deleted = false))
+      }
+    }
+    parentStateMap = newParentSessionStore
+  }
+}
+
+/**
+ * Companion object of [[OpenHashMapBasedStateMap]] having associated helper
+ * classes and methods
+ */
+private[streaming] object OpenHashMapBasedStateMap {
+
+  /** Internal class to represent the state information */
+  case class StateInfo[S](
+      var data: S = null.asInstanceOf[S],
+      var updateTime: Long = -1,
+      var deleted: Boolean = false) {
+
+    def markDeleted(): Unit = {
+      deleted = true
+    }
+
+    def update(newData: S, newUpdateTime: Long): Unit = {
+      data = newData
+      updateTime = newUpdateTime
+      deleted = false
+    }
+  }
+
+  /**
+   * Internal class to represent a marker that demarcates the end of all state data in the
+   * serialized bytes.
+   */
+  class LimitMarker(val num: Int) extends Serializable
+
+  val DELTA_CHAIN_LENGTH_THRESHOLD = 20
+}
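
The delta-chain idea behind `OpenHashMapBasedStateMap` (a copy stores only its own changes and falls back to its parent for everything else) can be sketched as follows. This is an illustrative simplification, not the patch's code: `DeltaChainSketch` and `DeltaMap` are hypothetical names, update times, serialization and compaction are omitted, and a `None` entry is assumed to play the role of the `deleted` flag in `StateInfo`.

```scala
import scala.collection.mutable

object DeltaChainSketch {

  /** Hypothetical simplified map; a `None` entry in the delta is a tombstone for a removed key. */
  final class DeltaMap[K, S](parent: Option[DeltaMap[K, S]] = None) {
    private val delta = mutable.HashMap[K, Option[S]]()

    def get(key: K): Option[S] = delta.get(key) match {
      case Some(entry) => entry                // a local put or tombstone always wins
      case None => parent.flatMap(_.get(key))  // otherwise fall back to the parent chain
    }

    def put(key: K, state: S): Unit = delta.update(key, Some(state))

    /** Record a tombstone so that a value in the parent is hidden without mutating the parent. */
    def remove(key: K): Unit = delta.update(key, None)

    /** Cheap copy: the new map starts with an empty delta and shares `this` as its parent. */
    def copy(): DeltaMap[K, S] = new DeltaMap[K, S](Some(this))

    /** Number of parent links below this map; a root map has length 0. */
    def chainLength: Int = parent.map(_.chainLength + 1).getOrElse(0)
  }
}
```

Copying is constant-time in this scheme, but lookups may walk the whole chain, which is why the patch compacts the chain during serialization once it exceeds `deltaChainThreshold`.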

http://git-wip-us.apache.org/repos/asf/spark/blob/99f5f988/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
new file mode 100644
index 0000000..48d3b41
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.collection.{immutable, mutable, Map}
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap}
+import org.apache.spark.util.Utils
+
+class StateMapSuite extends SparkFunSuite {
+
+  test("EmptyStateMap") {
+    val map = new EmptyStateMap[Int, Int]
+    intercept[scala.NotImplementedError] {
+      map.put(1, 1, 1)
+    }
+    assert(map.get(1) === None)
+    assert(map.getByTime(10000).isEmpty)
+    assert(map.getAll().isEmpty)
+    map.remove(1)   // no exception
+    assert(map.copy().eq(map))
+  }
+
+  test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") {
+    val map = new OpenHashMapBasedStateMap[Int, Int]()
+
+    map.put(1, 100, 10)
+    assert(map.get(1) === Some(100))
+    assert(map.get(2) === None)
+    assert(map.getByTime(11).toSet === Set((1, 100, 10)))
+    assert(map.getByTime(10).toSet === Set.empty)
+    assert(map.getByTime(9).toSet === Set.empty)
+    assert(map.getAll().toSet === Set((1, 100, 10)))
+
+    map.put(2, 200, 20)
+    assert(map.getByTime(21).toSet === Set((1, 100, 10), (2, 200, 20)))
+    assert(map.getByTime(11).toSet === Set((1, 100, 10)))
+    assert(map.getByTime(10).toSet === Set.empty)
+    assert(map.getByTime(9).toSet === Set.empty)
+    assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20)))
+
+    map.remove(1)
+    assert(map.get(1) === None)
+    assert(map.getAll().toSet === Set((2, 200, 20)))
+  }
+
+  test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove with copy") {
+    val parentMap = new OpenHashMapBasedStateMap[Int, Int]()
+    parentMap.put(1, 100, 1)
+    parentMap.put(2, 200, 2)
+    parentMap.remove(1)
+
+    // Create child map and make changes
+    val map = parentMap.copy()
+    assert(map.get(1) === None)
+    assert(map.get(2) === Some(200))
+    assert(map.getByTime(10).toSet === Set((2, 200, 2)))
+    assert(map.getByTime(2).toSet === Set.empty)
+    assert(map.getAll().toSet === Set((2, 200, 2)))
+
+    // Add new items
+    map.put(3, 300, 3)
+    assert(map.get(3) === Some(300))
+    map.put(4, 400, 4)
+    assert(map.get(4) === Some(400))
+    assert(map.getByTime(10).toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4)))
+    assert(map.getByTime(4).toSet === Set((2, 200, 2), (3, 300, 3)))
+    assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4)))
+    assert(parentMap.getAll().toSet === Set((2, 200, 2)))
+
+    // Remove items
+    map.remove(4)
+    assert(map.get(4) === None)       // item added in this map, then removed in this map
+    map.remove(2)
+    assert(map.get(2) === None)       // item added in parent map, then removed in this map
+    assert(map.getAll().toSet === Set((3, 300, 3)))
+    assert(parentMap.getAll().toSet === Set((2, 200, 2)))
+
+    // Update items
+    map.put(1, 1000, 100)
+    assert(map.get(1) === Some(1000)) // item removed in parent map, then added in this map
+    map.put(2, 2000, 200)
+    assert(map.get(2) === Some(2000)) // item added in parent map, then removed + added in this map
+    map.put(3, 3000, 300)
+    assert(map.get(3) === Some(3000)) // item added + updated in this map
+    map.put(4, 4000, 400)
+    assert(map.get(4) === Some(4000)) // item removed + updated in this map
+
+    assert(map.getAll().toSet ===
+      Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400)))
+    assert(parentMap.getAll().toSet === Set((2, 200, 2)))
+
+    map.remove(2)         // remove item present in parent map, so that it is not visible in child map
+
+    // Create child map and see availability of items
+    val childMap = map.copy()
+    assert(childMap.getAll().toSet === map.getAll().toSet)
+    assert(childMap.get(1) === Some(1000))  // item removed in grandparent, but added in parent map
+    assert(childMap.get(2) === None)        // item added in grandparent, but removed in parent map
+    assert(childMap.get(3) === Some(3000))  // item added and updated in parent map
+
+    childMap.put(2, 20000, 200)
+    assert(childMap.get(2) === Some(20000)) // item removed in parent map, then added in this map
+  }
+
+  test("OpenHashMapBasedStateMap - serializing and deserializing") {
+    val map1 = new OpenHashMapBasedStateMap[Int, Int]()
+    map1.put(1, 100, 1)
+    map1.put(2, 200, 2)
+
+    val map2 = map1.copy()
+    map2.put(3, 300, 3)
+    map2.put(4, 400, 4)
+
+    val map3 = map2.copy()
+    map3.put(3, 600, 3)
+    map3.remove(2)
+
+    // Do not test compaction
+    assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+
+    val deser_map3 = Utils.deserialize[StateMap[Int, Int]](
+      Utils.serialize(map3), Thread.currentThread().getContextClassLoader)
+    assertMap(deser_map3, map3, 1, "Deserialized map not same as original map")
+  }
+
+  test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") {
+    val targetDeltaLength = 10
+    val deltaChainThreshold = 5
+
+    var map = new OpenHashMapBasedStateMap[Int, Int](
+      deltaChainThreshold = deltaChainThreshold)
+
+    // Make large delta chain with length more than deltaChainThreshold
+    for(i <- 1 to targetDeltaLength) {
+      map.put(Random.nextInt(), Random.nextInt(), 1)
+      map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]]
+    }
+    assert(map.deltaChainLength > deltaChainThreshold)
+    assert(map.shouldCompact === true)
+
+    val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]](
+      Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+    assert(deser_map.deltaChainLength < deltaChainThreshold)
+    assert(deser_map.shouldCompact === false)
+    assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map")
+  }
+
+  test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") {
+    /*
+     * This tests the map using all permutations of sequences of operations, across multiple map
+     * copies as well as between copies. It is to ensure complete coverage, though it is
+     * kind of hard to debug this. It is set up as follows.
+     *
+     * - For any key, there can be 2 types of update ops on a state map - put or remove
+     *
+     * - These operations are done on a test map in "sets". After each set, the map is "copied"
+     *   to create a new map, and the next set of operations are done on the new one. This tests
+     *   whether the map data persists correctly across copies.
+     *
+     * - Within each set, there are a number of operations to test whether the map correctly
+     *   updates and removes data without affecting the parent state map.
+     *
+     * - Overall this creates (numSets * numOpsPerSet) operations, each of which can be one of 2
+     *   types of operations. This leads to a total of [2 ^ (numSets * numOpsPerSet)] different
+     *   sequences of operations, which we will test with different keys.
+     *
+     * Example: numSets = 2 and numOpsPerSet = 2 give numTotalOps = 4. This means that
+     * 2 ^ 4 = 16 possible permutations need to be tested using 16 keys.
+     * _______________________________________________
+     * |         |      Set1       |     Set2        |
+     * |         |-----------------|-----------------|
+     * |         |   Op1    Op2   |c|   Op3    Op4   |
+     * |---------|----------------|o|----------------|
+     * | key 0   |   put    put   |p|   put    put   |
+     * | key 1   |   put    put   |y|   put    rem   |
+     * | key 2   |   put    put   | |   rem    put   |
+     * | key 3   |   put    put   |t|   rem    rem   |
+     * | key 4   |   put    rem   |h|   put    put   |
+     * | key 5   |   put    rem   |e|   put    rem   |
+     * | key 6   |   put    rem   | |   rem    put   |
+     * | key 7   |   put    rem   |s|   rem    rem   |
+     * | key 8   |   rem    put   |t|   put    put   |
+     * | key 9   |   rem    put   |a|   put    rem   |
+     * | key 10  |   rem    put   |t|   rem    put   |
+     * | key 11  |   rem    put   |e|   rem    rem   |
+     * | key 12  |   rem    rem   | |   put    put   |
+     * | key 13  |   rem    rem   |m|   put    rem   |
+     * | key 14  |   rem    rem   |a|   rem    put   |
+     * | key 15  |   rem    rem   |p|   rem    rem   |
+     * |_________|________________|_|________________|
+     */
+
+    val numTypeMapOps = 2   // 0 = put a new value, 1 = remove value
+    val numSets = 3
+    val numOpsPerSet = 3    // to test seq of ops like update -> remove -> update in same set
+    val numTotalOps = numOpsPerSet * numSets
+    val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt  // to get all combinations of ops
+
+    val refMap = new mutable.HashMap[Int, (Int, Long)]()
+    var prevSetRefMap: immutable.Map[Int, (Int, Long)] = null
+
+    var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]()
+    var prevSetStateMap: StateMap[Int, Int] = null
+
+    var time = 1L
+
+    for (setId <- 0 until numSets) {
+      for (opInSetId <- 0 until numOpsPerSet) {
+        val opId = setId * numOpsPerSet + opInSetId
+        for (keyId <- 0 until numKeys) {
+          time += 1
+          // Find the operation type that needs to be done
+          // This is similar to finding the nth bit value of a binary number
+          // E.g.  nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2
+          val opCode =
+            (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps
+          opCode match {
+            case 0 =>
+              val value = Random.nextInt()
+              stateMap.put(keyId, value, time)
+              refMap.put(keyId, (value, time))
+            case 1 =>
+              stateMap.remove(keyId)
+              refMap.remove(keyId)
+          }
+        }
+
+        // Test whether the current state map after all key updates is correct
+        assertMap(stateMap, refMap, time, "State map does not match reference map")
+
+        // Test whether the previous map before copy has not changed
+        if (prevSetStateMap != null && prevSetRefMap != null) {
+          assertMap(prevSetStateMap, prevSetRefMap, time,
+            "Parent state map somehow got modified, does not match corresponding reference map")
+        }
+      }
+
+      // Copy the map and remember the previous maps for future tests
+      prevSetStateMap = stateMap
+      prevSetRefMap = refMap.toMap
+      stateMap = stateMap.copy()
+
+      // Assert that the copied map has the same data
+      assertMap(stateMap, prevSetRefMap, time,
+        "State map does not match reference map after copying")
+    }
+    assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
+  }
+
+  // Assert that all the data and operations on a state map match those of a reference state map
+  private def assertMap(
+      mapToTest: StateMap[Int, Int],
+      refMapToTestWith: StateMap[Int, Int],
+      time: Long,
+      msg: String): Unit = {
+    withClue(msg) {
+      // Assert all the data is same as the reference map
+      assert(mapToTest.getAll().toSet === refMapToTestWith.getAll().toSet)
+
+      // Assert that get on every key returns the right value
+      for (keyId <- refMapToTestWith.getAll().map { _._1 }) {
+        assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId))
+      }
+
+      // Assert that every time threshold returns the correct data
+      for (t <- 0L to (time + 1)) {
+        assert(mapToTest.getByTime(t).toSet === refMapToTestWith.getByTime(t).toSet)
+      }
+    }
+  }
+
+  // Assert that all the data and operations on a state map match those of a reference map
+  private def assertMap(
+      mapToTest: StateMap[Int, Int],
+      refMapToTestWith: Map[Int, (Int, Long)],
+      time: Long,
+      msg: String): Unit = {
+    withClue(msg) {
+      // Assert all the data is same as the reference map
+      assert(mapToTest.getAll().toSet ===
+        refMapToTestWith.iterator.map { x => (x._1, x._2._1, x._2._2) }.toSet)
+
+      // Assert that get on every key returns the right value
+      for (keyId <- refMapToTestWith.keys) {
+        assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId).map { _._1 })
+      }
+
+      // Assert that every time threshold returns the correct data
+      for (t <- 0L to (time + 1)) {
+        val expectedRecords =
+          refMapToTestWith.iterator.filter { _._2._2 < t }.map { x => (x._1, x._2._1, x._2._2) }
+        assert(mapToTest.getByTime(t).toSet === expectedRecords.toSet)
+      }
+    }
+  }
+}
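
Note on the op-code enumeration used in the StateMap test above: each key id is read as a base-2 number whose digits, left to right, choose the operation (0 = put, 1 = remove) applied to that key at each step. The following is a minimal standalone sketch of that idea, not part of the patch. The object name OpCodeSketch and the helper opsFor are hypothetical, and numTotalOps is set to 4 (the test uses 9) so that the result matches the 16-key table in the test's comment.

```scala
object OpCodeSketch {
  val numTypeMapOps = 2   // 0 = put a new value, 1 = remove value
  val numTotalOps = 4     // assumption: 4 steps (2 sets x 2 ops), as in the comment table
  val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt  // 16 keys cover all combinations

  // Operation applied to `keyId` at step `opId` (0-based, left to right).
  // Same arithmetic as the test: extract one base-2 digit of keyId.
  def opCode(keyId: Int, opId: Int): Int =
    (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps

  // Hypothetical helper: the full sequence of operations for one key.
  def opsFor(keyId: Int): Seq[String] =
    (0 until numTotalOps).map { opId => if (opCode(keyId, opId) == 0) "put" else "rem" }

  def main(args: Array[String]): Unit = {
    for (keyId <- 0 until numKeys) {
      val ops = opsFor(keyId).mkString(" ")
      println(s"key $keyId: $ops")
    }
  }
}
```

For example, key 9 is 1001 in binary, so it sees rem, put, put, rem, which agrees with the "key 9" row of the table.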

http://git-wip-us.apache.org/repos/asf/spark/blob/99f5f988/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
new file mode 100644
index 0000000..e3072b4
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
@@ -0,0 +1,494 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.File
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.reflect.ClassTag
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+
+import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl}
+import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+
+class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
+
+  private var sc: SparkContext = null
+  private var ssc: StreamingContext = null
+  private var checkpointDir: File = null
+  private val batchDuration = Seconds(1)
+
+  before {
+    StreamingContext.getActive().foreach {
+      _.stop(stopSparkContext = false)
+    }
+    checkpointDir = Utils.createTempDir("checkpoint")
+
+    ssc = new StreamingContext(sc, batchDuration)
+    ssc.checkpoint(checkpointDir.toString)
+  }
+
+  after {
+    StreamingContext.getActive().foreach {
+      _.stop(stopSparkContext = false)
+    }
+  }
+
+  override def beforeAll(): Unit = {
+    val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite")
+    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
+    sc = new SparkContext(conf)
+  }
+
+  test("state - get, exists, update, remove") {
+    var state: StateImpl[Int] = null
+
+    def testState(
+        expectedData: Option[Int],
+        shouldBeUpdated: Boolean = false,
+        shouldBeRemoved: Boolean = false,
+        shouldBeTimingOut: Boolean = false
+      ): Unit = {
+      if (expectedData.isDefined) {
+        assert(state.exists)
+        assert(state.get() === expectedData.get)
+        assert(state.getOption() === expectedData)
+        assert(state.getOption.getOrElse(-1) === expectedData.get)
+      } else {
+        assert(!state.exists)
+        intercept[NoSuchElementException] {
+          state.get()
+        }
+        assert(state.getOption() === None)
+        assert(state.getOption.getOrElse(-1) === -1)
+      }
+
+      assert(state.isTimingOut() === shouldBeTimingOut)
+      if (shouldBeTimingOut) {
+        intercept[IllegalArgumentException] {
+          state.remove()
+        }
+        intercept[IllegalArgumentException] {
+          state.update(-1)
+        }
+      }
+
+      assert(state.isUpdated() === shouldBeUpdated)
+
+      assert(state.isRemoved() === shouldBeRemoved)
+      if (shouldBeRemoved) {
+        intercept[IllegalArgumentException] {
+          state.remove()
+        }
+        intercept[IllegalArgumentException] {
+          state.update(-1)
+        }
+      }
+    }
+
+    state = new StateImpl[Int]()
+    testState(None)
+
+    state.wrap(None)
+    testState(None)
+
+    state.wrap(Some(1))
+    testState(Some(1))
+
+    state.update(2)
+    testState(Some(2), shouldBeUpdated = true)
+
+    state = new StateImpl[Int]()
+    state.update(2)
+    testState(Some(2), shouldBeUpdated = true)
+
+    state.remove()
+    testState(None, shouldBeRemoved = true)
+
+    state.wrapTiminoutState(3)
+    testState(Some(3), shouldBeTimingOut = true)
+  }
+
+  test("trackStateByKey - basic operations with simple API") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(1),
+        Seq(2, 1),
+        Seq(3, 2, 1),
+        Seq(4, 3),
+        Seq(5),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    // state maintains running count, and updated count is returned
+    val trackStateFunc = (value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      state.update(sum)
+      sum
+    }
+
+    testOperation[String, Int, Int](
+      inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+  }
+
+  test("trackStateByKey - basic operations with advanced API") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq("aa"),
+        Seq("aa", "bb"),
+        Seq("aa", "bb", "cc"),
+        Seq("aa", "bb"),
+        Seq("aa"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    // state maintains running count, key string doubled and returned
+    val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      state.update(sum)
+      Some(key * 2)
+    }
+
+    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+  }
+
+  test("trackStateByKey - type inferencing and class tags") {
+
+    // Simple track state function with value as Int, state as Double and emitted type as Long
+    val simpleFunc = (value: Option[Int], state: State[Double]) => {
+      0L
+    }
+
+    // Advanced track state function with key as String, value as Int, state as Double and
+    // emitted type as Long
+    val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
+      Some(0L)
+    }
+
+    def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = {
+      val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]]
+      assert(dstreamImpl.keyClass === classOf[String])
+      assert(dstreamImpl.valueClass === classOf[Int])
+      assert(dstreamImpl.stateClass === classOf[Double])
+      assert(dstreamImpl.emittedClass === classOf[Long])
+    }
+
+    val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
+
+    // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
+    val simpleFunctionStateStream1 = inputStream.trackStateByKey(
+      StateSpec.function(simpleFunc).numPartitions(1))
+    testTypes(simpleFunctionStateStream1)
+
+    // Separately defining StateSpec with simple function requires explicitly specifying types
+    val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
+    val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec)
+    testTypes(simpleFunctionStateStream2)
+
+    // Separately defining StateSpec with advanced function implicitly gets the types
+    val advFuncSpec1 = StateSpec.function(advancedFunc)
+    val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1)
+    testTypes(advFunctionStateStream1)
+
+    // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
+    val advFunctionStateStream2 = inputStream.trackStateByKey(
+      StateSpec.function(advancedFunc).numPartitions(1))
+    testTypes(advFunctionStateStream2)
+
+    // Separately defining StateSpec with advanced func and explicitly specified types
+    val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
+    val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2)
+    testTypes(advFunctionStateStream3)
+  }
+
+  test("trackStateByKey - states as emitted records") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3)),
+        Seq(("a", 5)),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (key, sum)
+      state.update(sum)
+      Some(output)
+    }
+
+    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+  }
+
+  test("trackStateByKey - initial states, with nothing emitted") {
+
+    val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0))
+
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData = Seq.fill(inputData.size)(Seq.empty[Int])
+
+    val stateData =
+      Seq(
+        Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)),
+        Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)),
+        Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)),
+        Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)),
+        Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)),
+        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)),
+        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0))
+      )
+
+    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (key, sum)
+      state.update(sum)
+      Option.empty[Int]
+    }
+
+    val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState))
+    testOperation(inputData, trackStateSpec, outputData, stateData)
+  }
+
+  test("trackStateByKey - state removing") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"), // a will be removed
+        Seq("a", "b", "c"), // b will be removed
+        Seq("a", "b", "c"), // a and c will be removed
+        Seq("a", "b"), // b will be removed
+        Seq("a"), // a will be removed
+        Seq()
+      )
+
+    // States that were removed
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(),
+        Seq("a"),
+        Seq("b"),
+        Seq("a", "c"),
+        Seq("b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("b", 1)),
+        Seq(("a", 1), ("c", 1)),
+        Seq(("b", 1)),
+        Seq(("a", 1)),
+        Seq(),
+        Seq()
+      )
+
+    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      if (state.exists) {
+        state.remove()
+        Some(key)
+      } else {
+        state.update(value.get)
+        None
+      }
+    }
+
+    testOperation(
+      inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData)
+  }
+
+  test("trackStateByKey - state timing out") {
+    val inputData =
+      Seq(
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq(), // c will time out
+        Seq(), // b will time out
+        Seq("a") // a will not time out
+      ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active
+
+    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      if (value.isDefined) {
+        state.update(1)
+      }
+      if (state.isTimingOut) {
+        Some(key)
+      } else {
+        None
+      }
+    }
+
+    val (collectedOutputs, collectedStateSnapshots) = getOperationOutput(
+      inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20)
+
+    // b and c should be emitted once each, when they were marked as expired
+    assert(collectedOutputs.flatten.sorted === Seq("b", "c"))
+
+    // States for a, b, c should be defined at one point of time
+    assert(collectedStateSnapshots.exists {
+      _.toSet == Set(("a", 1), ("b", 1), ("c", 1))
+    })
+
+    // Finally state should be defined only for a
+    assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
+  }
+
+
+  private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
+      input: Seq[Seq[K]],
+      trackStateSpec: StateSpec[K, Int, S, T],
+      expectedOutputs: Seq[Seq[T]],
+      expectedStateSnapshots: Seq[Seq[(K, S)]]
+    ): Unit = {
+    require(expectedOutputs.size == expectedStateSnapshots.size)
+
+    val (collectedOutputs, collectedStateSnapshots) =
+      getOperationOutput(input, trackStateSpec, expectedOutputs.size)
+    assert(expectedOutputs, collectedOutputs, "outputs")
+    assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots")
+  }
+
+  private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag](
+      input: Seq[Seq[K]],
+      trackStateSpec: StateSpec[K, Int, S, T],
+      numBatches: Int
+    ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
+
+    // Setup the stream computation
+    val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
+    val trackStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
+    val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
+    val outputStream = new TestOutputStream(trackStateStream, collectedOutputs)
+    val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]]
+    val stateSnapshotStream = new TestOutputStream(
+      trackStateStream.stateSnapshots(), collectedStateSnapshots)
+    outputStream.register()
+    stateSnapshotStream.register()
+
+    val batchCounter = new BatchCounter(ssc)
+    ssc.start()
+
+    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+    clock.advance(batchDuration.milliseconds * numBatches)
+
+    batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
+    (collectedOutputs, collectedStateSnapshots)
+  }
+
+  private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String): Unit = {
+    val debugString = "\nExpected:\n" + expected.mkString("\n") +
+      "\nCollected:\n" + collected.mkString("\n")
+    assert(expected.size === collected.size,
+      s"number of collected $typ (${collected.size}) different from expected (${expected.size})" +
+        debugString)
+    expected.zip(collected).foreach { case (e, c) =>
+      assert(c.toSet === e.toSet,
+        s"collected $typ is different from expected $debugString"
+      )
+    }
+  }
+}
+
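
Note on the expected data in the "basic operations with simple API" test: the tracking function adds each new value to the key's existing state, stores the sum, and emits it. The sketch below replays that logic without Spark so the expected outputs and state snapshots can be checked by hand. It is an illustration under stated assumptions, not the patch's implementation: SimpleState is a hypothetical stand-in for State[S] that models only getOption and update, and run is a hypothetical driver that processes one record per key occurrence with a value of 1, as the test helper does.

```scala
import scala.collection.mutable

object RunningCountSketch {
  // Hypothetical stand-in for State[S]; models only getOption and update.
  final class SimpleState[S](private var value: Option[S]) {
    def getOption: Option[S] = value
    def update(newValue: S): Unit = { value = Some(newValue) }
  }

  // Same body as the simple tracking function in the test:
  // the state keeps a running count, and the updated count is emitted.
  val trackStateFunc = (value: Option[Int], state: SimpleState[Int]) => {
    val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
    state.update(sum)
    sum
  }

  // Hypothetical driver: for each batch, returns (emitted records, state snapshot).
  def run(batches: Seq[Seq[String]]): Seq[(Seq[Int], Map[String, Int])] = {
    val states = mutable.Map.empty[String, Int]
    batches.map { batch =>
      val emitted = batch.map { key =>
        val state = new SimpleState[Int](states.get(key))
        val out = trackStateFunc(Some(1), state)   // each key occurrence carries the value 1
        state.getOption.foreach { v => states(key) = v }
        out
      }
      (emitted, states.toMap)
    }
  }
}
```

Running it on the test's input batches reproduces the test's outputData and stateData: keys that receive no data in a batch emit nothing but keep their state in the snapshot.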


