[SPARK-20883][SPARK-20376][SS] Refactored StateStore APIs and added conf to 
choose implementation

## What changes were proposed in this pull request?

A bunch of changes to the StateStore APIs and implementation.
The current state store API has a bunch of problems that cause too many 
transient objects to be created, causing memory pressure.

- `StateStore.get(): Option` forces creation of Some/None objects for every 
get. Changed this to return the row or null.
- `StateStore.iterator(): (UnsafeRow, UnsafeRow)` forces creation of a new tuple 
for each record returned. Changed this to return an `UnsafeRowPair`, which can be 
reused across records.
- `StateStore.updates()` requires the implementation to keep track of updates, 
while this is used minimally (only by Append mode in streaming aggregations). 
Removed updates() and updated StateStoreSaveExec accordingly.
- `StateStore.filter(condition)` and `StateStore.remove(condition)` have been 
merged into a single API `getRange(start, end)` which allows a state store to do 
optimized range queries (i.e. avoid full scans). Stateful operators have been 
updated accordingly.
- Removed a lot of unnecessary row copies. Each operator copied rows before 
calling StateStore.put() even if the implementation does not require them to be 
copied. It is left up to the implementation on whether to copy the row or not.

Additionally,
- Added a name to the StateStoreId so that each operator+partition can use 
multiple state stores (different names)
- Added a configuration that allows the user to specify which implementation to 
use.
- Added new metrics to understand the time taken to update keys, remove keys 
and commit all changes to the state store. These metrics will be visible on the 
plan diagram in the SQL tab of the UI.
- Refactored unit tests such that they can be reused to test any implementation 
of StateStore.

## How was this patch tested?
Old and new unit tests

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #18107 from tdas/SPARK-20376.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fa757ee1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fa757ee1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fa757ee1

Branch: refs/heads/master
Commit: fa757ee1d41396ad8734a3f2dd045bb09bc82a2e
Parents: 4bb6a53
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Tue May 30 15:33:06 2017 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Tue May 30 15:33:06 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/internal/SQLConf.scala |  11 +
 .../streaming/FlatMapGroupsWithStateExec.scala  |  39 +-
 .../state/HDFSBackedStateStoreProvider.scala    | 218 +++-----
 .../execution/streaming/state/StateStore.scala  | 163 ++++--
 .../streaming/state/StateStoreConf.scala        |  28 +-
 .../streaming/state/StateStoreRDD.scala         |  11 +-
 .../sql/execution/streaming/state/package.scala |  11 +-
 .../execution/streaming/statefulOperators.scala | 142 +++--
 .../streaming/state/StateStoreRDDSuite.scala    |  41 +-
 .../streaming/state/StateStoreSuite.scala       | 534 +++++++++----------
 .../streaming/FlatMapGroupsWithStateSuite.scala |  40 +-
 .../spark/sql/streaming/StreamSuite.scala       |  45 ++
 12 files changed, 695 insertions(+), 588 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c5d69c2..c6f5cf6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -552,6 +552,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val STATE_STORE_PROVIDER_CLASS =
+    buildConf("spark.sql.streaming.stateStore.providerClass")
+      .internal()
+      .doc(
+        "The class used to manage state data in stateful streaming queries. 
This class must " +
+          "be a subclass of StateStoreProvider, and must have a zero-arg 
constructor.")
+      .stringConf
+      .createOptional
+
   val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
     buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
       .internal()
@@ -828,6 +837,8 @@ class SQLConf extends Serializable with Logging {
 
   def optimizerInSetConversionThreshold: Int = 
getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)
 
+  def stateStoreProviderClass: Option[String] = 
getConf(STATE_STORE_PROVIDER_CLASS)
+
   def stateStoreMinDeltasForSnapshot: Int = 
getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
 
   def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 3ceb4cf..2aad870 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -109,9 +109,11 @@ case class FlatMapGroupsWithStateExec(
     child.execute().mapPartitionsWithStateStore[InternalRow](
       getStateId.checkpointLocation,
       getStateId.operatorId,
+      storeName = "default",
       getStateId.batchId,
       groupingAttributes.toStructType,
       stateAttributes.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
         val updater = new StateStoreUpdater(store)
@@ -191,12 +193,12 @@ case class FlatMapGroupsWithStateExec(
             throw new IllegalStateException(
               s"Cannot filter timed out keys for $timeoutConf")
         }
-        val timingOutKeys = store.filter { case (_, stateRow) =>
-          val timeoutTimestamp = getTimeoutTimestamp(stateRow)
+        val timingOutKeys = store.getRange(None, None).filter { rowPair =>
+          val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
           timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < 
timeoutThreshold
         }
-        timingOutKeys.flatMap { case (keyRow, stateRow) =>
-          callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), 
hasTimedOut = true)
+        timingOutKeys.flatMap { rowPair =>
+          callFunctionAndUpdateState(rowPair.key, Iterator.empty, 
rowPair.value, hasTimedOut = true)
         }
       } else Iterator.empty
     }
@@ -205,18 +207,23 @@ case class FlatMapGroupsWithStateExec(
      * Call the user function on a key's data, update the state store, and 
return the return data
      * iterator. Note that the store updating is lazy, that is, the store will 
be updated only
      * after the returned iterator is fully consumed.
+     *
+     * @param keyRow Row representing the key, cannot be null
+     * @param valueRowIter Iterator of values as rows, cannot be null, but can 
be empty
+     * @param prevStateRow Row representing the previous state, can be null
+     * @param hasTimedOut Whether this function is being called for a key 
timeout
      */
     private def callFunctionAndUpdateState(
         keyRow: UnsafeRow,
         valueRowIter: Iterator[InternalRow],
-        prevStateRowOption: Option[UnsafeRow],
+        prevStateRow: UnsafeRow,
         hasTimedOut: Boolean): Iterator[InternalRow] = {
 
       val keyObj = getKeyObj(keyRow)  // convert key to objects
       val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value 
rows to objects
-      val stateObjOption = getStateObj(prevStateRowOption)
+      val stateObj = getStateObj(prevStateRow)
       val keyedState = GroupStateImpl.createForStreaming(
-        stateObjOption,
+        Option(stateObj),
         batchTimestampMs.getOrElse(NO_TIMESTAMP),
         eventTimeWatermark.getOrElse(NO_TIMESTAMP),
         timeoutConf,
@@ -249,14 +256,11 @@ case class FlatMapGroupsWithStateExec(
           numUpdatedStateRows += 1
 
         } else {
-          val previousTimeoutTimestamp = prevStateRowOption match {
-            case Some(row) => getTimeoutTimestamp(row)
-            case None => NO_TIMESTAMP
-          }
+          val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
           val stateRowToWrite = if (keyedState.hasUpdated) {
             getStateRow(keyedState.get)
           } else {
-            prevStateRowOption.orNull
+            prevStateRow
           }
 
           val hasTimeoutChanged = currentTimeoutTimestamp != 
previousTimeoutTimestamp
@@ -269,7 +273,7 @@ case class FlatMapGroupsWithStateExec(
               throw new IllegalStateException("Attempting to write empty 
state")
             }
             setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
-            store.put(keyRow.copy(), stateRowToWrite.copy())
+            store.put(keyRow, stateRowToWrite)
             numUpdatedStateRows += 1
           }
         }
@@ -280,18 +284,21 @@ case class FlatMapGroupsWithStateExec(
     }
 
     /** Returns the state as Java object if defined */
-    def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = {
-      stateRowOption.map(getStateObjFromRow)
+    def getStateObj(stateRow: UnsafeRow): Any = {
+      if (stateRow != null) getStateObjFromRow(stateRow) else null
     }
 
     /** Returns the row for an updated state */
     def getStateRow(obj: Any): UnsafeRow = {
+      assert(obj != null)
       getStateRowFromObj(obj)
     }
 
     /** Returns the timeout timestamp of a state row is set */
     def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
-      if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else 
NO_TIMESTAMP
+      if (isTimeoutEnabled && stateRow != null) {
+        stateRow.getLong(timeoutTimestampIndex)
+      } else NO_TIMESTAMP
     }
 
     /** Set the timestamp in a state row */

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index fb2bf47..67d86da 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -67,13 +67,7 @@ import org.apache.spark.util.Utils
  * to ensure re-executed RDD operations re-apply updates on the correct past 
version of the
  * store.
  */
-private[state] class HDFSBackedStateStoreProvider(
-    val id: StateStoreId,
-    keySchema: StructType,
-    valueSchema: StructType,
-    storeConf: StateStoreConf,
-    hadoopConf: Configuration
-  ) extends StateStoreProvider with Logging {
+private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider 
with Logging {
 
   // ConcurrentHashMap is used because it generates fail-safe iterators on 
filtering
   // - The iterator is weakly consistent with the map, i.e., iterator's data 
reflect the values in
@@ -95,92 +89,36 @@ private[state] class HDFSBackedStateStoreProvider(
     private val newVersion = version + 1
     private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
     private lazy val tempDeltaFileStream = 
compressStream(fs.create(tempDeltaFile, true))
-    private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]()
-
     @volatile private var state: STATE = UPDATING
     @volatile private var finalDeltaFile: Path = null
 
     override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
 
-    override def get(key: UnsafeRow): Option[UnsafeRow] = {
-      Option(mapToUpdate.get(key))
-    }
-
-    override def filter(
-        condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, 
UnsafeRow)] = {
-      mapToUpdate
-        .entrySet
-        .asScala
-        .iterator
-        .filter { entry => condition(entry.getKey, entry.getValue) }
-        .map { entry => (entry.getKey, entry.getValue) }
+    override def get(key: UnsafeRow): UnsafeRow = {
+      mapToUpdate.get(key)
     }
 
     override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot put after already committed or 
aborted")
-
-      val isNewKey = !mapToUpdate.containsKey(key)
-      mapToUpdate.put(key, value)
-
-      Option(allUpdates.get(key)) match {
-        case Some(ValueAdded(_, _)) =>
-          // Value did not exist in previous version and was added already, 
keep it marked as added
-          allUpdates.put(key, ValueAdded(key, value))
-        case Some(ValueUpdated(_, _)) | Some(ValueRemoved(_, _)) =>
-          // Value existed in previous version and updated/removed, mark it as 
updated
-          allUpdates.put(key, ValueUpdated(key, value))
-        case None =>
-          // There was no prior update, so mark this as added or updated 
according to its presence
-          // in previous version.
-          val update = if (isNewKey) ValueAdded(key, value) else 
ValueUpdated(key, value)
-          allUpdates.put(key, update)
-      }
-      writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
+      val keyCopy = key.copy()
+      val valueCopy = value.copy()
+      mapToUpdate.put(keyCopy, valueCopy)
+      writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy)
     }
 
-    /** Remove keys that match the following condition */
-    override def remove(condition: UnsafeRow => Boolean): Unit = {
+    override def remove(key: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or 
aborted")
-      val entryIter = mapToUpdate.entrySet().iterator()
-      while (entryIter.hasNext) {
-        val entry = entryIter.next
-        if (condition(entry.getKey)) {
-          val value = entry.getValue
-          val key = entry.getKey
-          entryIter.remove()
-
-          Option(allUpdates.get(key)) match {
-            case Some(ValueUpdated(_, _)) | None =>
-              // Value existed in previous version and maybe was updated, mark 
removed
-              allUpdates.put(key, ValueRemoved(key, value))
-            case Some(ValueAdded(_, _)) =>
-              // Value did not exist in previous version and was added, should 
not appear in updates
-              allUpdates.remove(key)
-            case Some(ValueRemoved(_, _)) =>
-              // Remove already in update map, no need to change
-          }
-          writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
-        }
+      val prevValue = mapToUpdate.remove(key)
+      if (prevValue != null) {
+        writeRemoveToDeltaFile(tempDeltaFileStream, key)
       }
     }
 
-    /** Remove a single key. */
-    override def remove(key: UnsafeRow): Unit = {
-      verify(state == UPDATING, "Cannot remove after already committed or 
aborted")
-      if (mapToUpdate.containsKey(key)) {
-        val value = mapToUpdate.remove(key)
-        Option(allUpdates.get(key)) match {
-          case Some(ValueUpdated(_, _)) | None =>
-            // Value existed in previous version and maybe was updated, mark 
removed
-            allUpdates.put(key, ValueRemoved(key, value))
-          case Some(ValueAdded(_, _)) =>
-            // Value did not exist in previous version and was added, should 
not appear in updates
-            allUpdates.remove(key)
-          case Some(ValueRemoved(_, _)) =>
-          // Remove already in update map, no need to change
-        }
-        writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
-      }
+    override def getRange(
+        start: Option[UnsafeRow],
+        end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
+      verify(state == UPDATING, "Cannot getRange after already committed or 
aborted")
+      iterator()
     }
 
     /** Commit all the updates that have been made to the store, and return 
the new version. */
@@ -227,20 +165,11 @@ private[state] class HDFSBackedStateStoreProvider(
      * Get an iterator of all the store data.
      * This can be called only after committing all the updates made in the 
current thread.
      */
-    override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
-      verify(state == COMMITTED,
-        "Cannot get iterator of store data before committing or after 
aborting")
-      HDFSBackedStateStoreProvider.this.iterator(newVersion)
-    }
-
-    /**
-     * Get an iterator of all the updates made to the store in the current 
version.
-     * This can be called only after committing all the updates made in the 
current thread.
-     */
-    override def updates(): Iterator[StoreUpdate] = {
-      verify(state == COMMITTED,
-        "Cannot get iterator of updates before committing or after aborting")
-      allUpdates.values().asScala.toIterator
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      val unsafeRowPair = new UnsafeRowPair()
+      mapToUpdate.entrySet.asScala.iterator.map { entry =>
+        unsafeRowPair.withRows(entry.getKey, entry.getValue)
+      }
     }
 
     override def numKeys(): Long = mapToUpdate.size()
@@ -269,6 +198,23 @@ private[state] class HDFSBackedStateStoreProvider(
     store
   }
 
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int], // for sorting the data
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    this.stateStoreId = stateStoreId
+    this.keySchema = keySchema
+    this.valueSchema = valueSchema
+    this.storeConf = storeConf
+    this.hadoopConf = hadoopConf
+    fs.mkdirs(baseDir)
+  }
+
+  override def id: StateStoreId = stateStoreId
+
   /** Do maintenance backing data files, including creating snapshots and 
cleaning up old files */
   override def doMaintenance(): Unit = {
     try {
@@ -280,19 +226,27 @@ private[state] class HDFSBackedStateStoreProvider(
     }
   }
 
+  override def close(): Unit = {
+    loadedMaps.values.foreach(_.clear())
+  }
+
   override def toString(): String = {
     s"HDFSStateStoreProvider[id = (op=${id.operatorId}, 
part=${id.partitionId}), dir = $baseDir]"
   }
 
-  /* Internal classes and methods */
+  /* Internal fields and methods */
 
-  private val loadedMaps = new mutable.HashMap[Long, MapType]
-  private val baseDir =
-    new Path(id.checkpointLocation, 
s"${id.operatorId}/${id.partitionId.toString}")
-  private val fs = baseDir.getFileSystem(hadoopConf)
-  private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new 
SparkConf)
+  @volatile private var stateStoreId: StateStoreId = _
+  @volatile private var keySchema: StructType = _
+  @volatile private var valueSchema: StructType = _
+  @volatile private var storeConf: StateStoreConf = _
+  @volatile private var hadoopConf: Configuration = _
 
-  initialize()
+  private lazy val loadedMaps = new mutable.HashMap[Long, MapType]
+  private lazy val baseDir =
+    new Path(id.checkpointLocation, 
s"${id.operatorId}/${id.partitionId.toString}")
+  private lazy val fs = baseDir.getFileSystem(hadoopConf)
+  private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new 
SparkConf)
 
   private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
 
@@ -323,35 +277,18 @@ private[state] class HDFSBackedStateStoreProvider(
    * Get iterator of all the data of the latest version of the store.
    * Note that this will look up the files to determined the latest known 
version.
    */
-  private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = 
synchronized {
+  private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized {
     val versionsInFiles = fetchFiles().map(_.version).toSet
     val versionsLoaded = loadedMaps.keySet
     val allKnownVersions = versionsInFiles ++ versionsLoaded
+    val unsafeRowTuple = new UnsafeRowPair()
     if (allKnownVersions.nonEmpty) {
-      loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x =>
-        (x.getKey, x.getValue)
+      loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { entry 
=>
+        unsafeRowTuple.withRows(entry.getKey, entry.getValue)
       }
     } else Iterator.empty
   }
 
-  /** Get iterator of a specific version of the store */
-  private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] 
= synchronized {
-    loadMap(version).entrySet().iterator().asScala.map { x =>
-      (x.getKey, x.getValue)
-    }
-  }
-
-  /** Initialize the store provider */
-  private def initialize(): Unit = {
-    try {
-      fs.mkdirs(baseDir)
-    } catch {
-      case e: IOException =>
-        throw new IllegalStateException(
-          s"Cannot use ${id.checkpointLocation} for storing state data for 
$this: $e ", e)
-    }
-  }
-
   /** Load the required version of the map data from the backing files */
   private def loadMap(version: Long): MapType = {
     if (version <= 0) return new MapType
@@ -367,32 +304,23 @@ private[state] class HDFSBackedStateStoreProvider(
     }
   }
 
-  private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): 
Unit = {
-
-    def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = {
-      val keyBytes = key.getBytes()
-      val valueBytes = value.getBytes()
-      output.writeInt(keyBytes.size)
-      output.write(keyBytes)
-      output.writeInt(valueBytes.size)
-      output.write(valueBytes)
-    }
-
-    def writeRemove(key: UnsafeRow): Unit = {
-      val keyBytes = key.getBytes()
-      output.writeInt(keyBytes.size)
-      output.write(keyBytes)
-      output.writeInt(-1)
-    }
+  private def writeUpdateToDeltaFile(
+      output: DataOutputStream,
+      key: UnsafeRow,
+      value: UnsafeRow): Unit = {
+    val keyBytes = key.getBytes()
+    val valueBytes = value.getBytes()
+    output.writeInt(keyBytes.size)
+    output.write(keyBytes)
+    output.writeInt(valueBytes.size)
+    output.write(valueBytes)
+  }
 
-    update match {
-      case ValueAdded(key, value) =>
-        writeUpdate(key, value)
-      case ValueUpdated(key, value) =>
-        writeUpdate(key, value)
-      case ValueRemoved(key, value) =>
-        writeRemove(key)
-    }
+  private def writeRemoveToDeltaFile(output: DataOutputStream, key: 
UnsafeRow): Unit = {
+    val keyBytes = key.getBytes()
+    output.writeInt(keyBytes.size)
+    output.write(keyBytes)
+    output.writeInt(-1)
   }
 
   private def finalizeDeltaFile(output: DataOutputStream): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index eaa558e..29c456f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -29,15 +29,12 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.ThreadUtils
-
-
-/** Unique identifier for a [[StateStore]] */
-case class StateStoreId(checkpointLocation: String, operatorId: Long, 
partitionId: Int)
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 
 /**
- * Base trait for a versioned key-value store used for streaming aggregations
+ * Base trait for a versioned key-value store. Each instance of a `StateStore` 
represents a specific
+ * version of state data, and such instances are created through a 
[[StateStoreProvider]].
  */
 trait StateStore {
 
@@ -47,50 +44,54 @@ trait StateStore {
   /** Version of the data in this store before committing updates. */
   def version: Long
 
-  /** Get the current value of a key. */
-  def get(key: UnsafeRow): Option[UnsafeRow]
-
   /**
-   * Return an iterator of key-value pairs that satisfy a certain condition.
-   * Note that the iterator must be fail-safe towards modification to the 
store, that is,
-   * it must be based on the snapshot of store the time of this call, and any 
change made to the
-   * store while iterating through iterator should not cause the iterator to 
fail or have
-   * any affect on the values in the iterator.
+   * Get the current value of a non-null key.
+   * @return a non-null row if the key exists in the store, otherwise null.
    */
-  def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): 
Iterator[(UnsafeRow, UnsafeRow)]
+  def get(key: UnsafeRow): UnsafeRow
 
-  /** Put a new value for a key. */
+  /**
+   * Put a new value for a non-null key. Implementations must be aware that 
the UnsafeRows in
+   * the params can be reused, and must make copies of the data as needed for 
persistence.
+   */
   def put(key: UnsafeRow, value: UnsafeRow): Unit
 
   /**
-   * Remove keys that match the following condition.
+   * Remove a single non-null key.
    */
-  def remove(condition: UnsafeRow => Boolean): Unit
+  def remove(key: UnsafeRow): Unit
 
   /**
-   * Remove a single key.
+   * Get key value pairs with optional approximate `start` and `end` extents.
+   * If the State Store implementation maintains indices for the data based on 
the optional
+   * `keyIndexOrdinal` over fields `keySchema` (see 
`StateStoreProvider.init()`), then it can use
+   * `start` and `end` to make a best-effort scan over the data. Default 
implementation returns
+   * the full data scan iterator, which is correct but inefficient. Custom 
implementations must
+   * ensure that updates (puts, removes) can be made while iterating over this 
iterator.
+   *
+   * @param start UnsafeRow having the `keyIndexOrdinal` column set with 
appropriate starting value.
+   * @param end UnsafeRow having the `keyIndexOrdinal` column set with 
appropriate ending value.
+   * @return An iterator of key-value pairs that is guaranteed not miss any 
key between start and
+   *         end, both inclusive.
    */
-  def remove(key: UnsafeRow): Unit
+  def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): 
Iterator[UnsafeRowPair] = {
+    iterator()
+  }
 
   /**
    * Commit all the updates that have been made to the store, and return the 
new version.
+   * Implementations should ensure that no more updates (puts, removes) can be 
after a commit in
+   * order to avoid incorrect usage.
    */
   def commit(): Long
 
-  /** Abort all the updates that have been made to the store. */
-  def abort(): Unit
-
   /**
-   * Iterator of store data after a set of updates have been committed.
-   * This can be called only after committing all the updates made in the 
current thread.
+   * Abort all the updates that have been made to the store. Implementations 
should ensure that
+   * no more updates (puts, removes) can be after an abort in order to avoid 
incorrect usage.
    */
-  def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
+  def abort(): Unit
 
-  /**
-   * Iterator of the updates that have been committed.
-   * This can be called only after committing all the updates made in the 
current thread.
-   */
-  def updates(): Iterator[StoreUpdate]
+  def iterator(): Iterator[UnsafeRowPair]
 
   /** Number of keys in the state store */
   def numKeys(): Long
@@ -102,28 +103,98 @@ trait StateStore {
 }
 
 
-/** Trait representing a provider of a specific version of a [[StateStore]]. */
+/**
+ * Trait representing a provider that provide [[StateStore]] instances 
representing
+ * versions of state data.
+ *
+ * The life cycle of a provider and its provide stores are as follows.
+ *
+ * - A StateStoreProvider is created in a executor for each unique 
[[StateStoreId]] when
+ *   the first batch of a streaming query is executed on the executor. All 
subsequent batches reuse
+ *   this provider instance until the query is stopped.
+ *
+ * - Every batch of streaming data request a specific version of the state 
data by invoking
+ *   `getStore(version)` which returns an instance of [[StateStore]] through 
which the required
+ *   version of the data can be accessed. It is the responsible of the 
provider to populate
+ *   this store with context information like the schema of keys and values, 
etc.
+ *
+ * - After the streaming query is stopped, the created provider instances are 
lazily disposed off.
+ */
 trait StateStoreProvider {
 
-  /** Get the store with the existing version. */
+  /**
+   * Initialize the provide with more contextual information from the SQL 
operator.
+   * This method will be called first after creating an instance of the 
StateStoreProvider by
+   * reflection.
+   *
+   * @param stateStoreId Id of the versioned StateStores that this provider 
will generate
+   * @param keySchema Schema of keys to be stored
+   * @param valueSchema Schema of value to be stored
+   * @param keyIndexOrdinal Optional column (represent as the ordinal of the 
field in keySchema) by
+   *                        which the StateStore implementation could index 
the data.
+   * @param storeConfs Configurations used by the StateStores
+   * @param hadoopConf Hadoop configuration that could be used by StateStore 
to save state data
+   */
+  def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyIndexOrdinal: Option[Int], // for sorting the data by their keys
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration): Unit
+
+  /**
+   * Return the id of the StateStores this provider will generate.
+   * Should be the same as the one passed in init().
+   */
+  def id: StateStoreId
+
+  /** Called when the provider instance is unloaded from the executor */
+  def close(): Unit
+
+  /** Return an instance of [[StateStore]] representing state data of the 
given version */
   def getStore(version: Long): StateStore
 
-  /** Optional method for providers to allow for background maintenance */
+  /** Optional method for providers to allow for background maintenance (e.g. 
compactions) */
   def doMaintenance(): Unit = { }
 }
 
-
-/** Trait representing updates made to a [[StateStore]]. */
-sealed trait StoreUpdate {
-  def key: UnsafeRow
-  def value: UnsafeRow
+object StateStoreProvider {
+  /**
+   * Return a provider instance of the given provider class.
+   * The instance will be already initialized.
+   */
+  def instantiate(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int], // for sorting the data
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): StateStoreProvider = {
+    val providerClass = storeConf.providerClass.map(Utils.classForName)
+        .getOrElse(classOf[HDFSBackedStateStoreProvider])
+    val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider]
+    provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, 
storeConf, hadoopConf)
+    provider
+  }
 }
 
-case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
 
-case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+/** Unique identifier for a bunch of keyed state data. */
+case class StateStoreId(
+    checkpointLocation: String,
+    operatorId: Long,
+    partitionId: Int,
+    name: String = "")
 
-case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+/** Mutable, and reusable class for representing a pair of UnsafeRows. */
+class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) {
+  def withRows(key: UnsafeRow, value: UnsafeRow): UnsafeRowPair = {
+    this.key = key
+    this.value = value
+    this
+  }
+}
 
 
 /**
@@ -185,6 +256,7 @@ object StateStore extends Logging {
       storeId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
+      indexOrdinal: Option[Int],
       version: Long,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStore = {
@@ -193,7 +265,9 @@ object StateStore extends Logging {
       startMaintenanceIfNeeded()
       val provider = loadedProviders.getOrElseUpdate(
         storeId,
-        new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, 
storeConf, hadoopConf))
+        StateStoreProvider.instantiate(
+          storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+      )
       reportActiveStoreInstance(storeId)
       provider
     }
@@ -202,7 +276,7 @@ object StateStore extends Logging {
 
   /** Unload a state store provider */
   def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
-    loadedProviders.remove(storeId)
+    loadedProviders.remove(storeId).foreach(_.close())
   }
 
   /** Whether a state store provider is loaded or not */
@@ -216,6 +290,7 @@ object StateStore extends Logging {
 
   /** Unload and stop all state store providers */
   def stop(): Unit = loadedProviders.synchronized {
+    loadedProviders.keySet.foreach { key => unload(key) }
     loadedProviders.clear()
     _coordRef = null
     if (maintenanceTask != null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index acfaa8e..bab297c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,16 +20,34 @@ package org.apache.spark.sql.execution.streaming.state
 import org.apache.spark.sql.internal.SQLConf
 
 /** A class that contains configuration parameters for [[StateStore]]s. */
-private[streaming] class StateStoreConf(@transient private val conf: SQLConf) 
extends Serializable {
+class StateStoreConf(@transient private val sqlConf: SQLConf)
+  extends Serializable {
 
   def this() = this(new SQLConf)
 
-  val minDeltasForSnapshot = conf.stateStoreMinDeltasForSnapshot
-
-  val minVersionsToRetain = conf.minBatchesToRetain
+  /**
+   * Minimum number of delta files in a chain after which HDFSBackedStateStore 
will
+   * consider generating a snapshot.
+   */
+  val minDeltasForSnapshot: Int = sqlConf.stateStoreMinDeltasForSnapshot
+
+  /** Minimum versions a State Store implementation should retain to allow 
rollbacks */
+  val minVersionsToRetain: Int = sqlConf.minBatchesToRetain
+
+  /**
+   * Optional fully qualified name of the subclass of [[StateStoreProvider]]
+   * managing state data. That is, the implementation of the State Store to 
use.
+   */
+  val providerClass: Option[String] = sqlConf.stateStoreProviderClass
+
+  /**
+   * Additional configurations related to state store. This will capture all 
configs in
+   * SQLConf that start with `spark.sql.streaming.stateStore.` */
+  val confs: Map[String, String] =
+    
sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore."))
 }
 
-private[streaming] object StateStoreConf {
+object StateStoreConf {
   val empty = new StateStoreConf()
 
   def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf)

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index e16dda8..b744c25 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -35,9 +35,11 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
     checkpointLocation: String,
     operatorId: Long,
+    storeName: String,
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
+    indexOrdinal: Option[Int],
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
   extends RDD[U](dataRDD) {
@@ -45,21 +47,22 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
   private val storeConf = new StateStoreConf(sessionState.conf)
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so 
broadcast it
-  private val confBroadcast = dataRDD.context.broadcast(
+  private val hadoopConfBroadcast = dataRDD.context.broadcast(
     new SerializableConfiguration(sessionState.newHadoopConf()))
 
   override protected def getPartitions: Array[Partition] = dataRDD.partitions
 
   override def getPreferredLocations(partition: Partition): Seq[String] = {
-    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+    val storeId = StateStoreId(checkpointLocation, operatorId, 
partition.index, storeName)
     storeCoordinator.flatMap(_.getLocation(storeId)).toSeq
   }
 
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = 
{
     var store: StateStore = null
-    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+    val storeId = StateStoreId(checkpointLocation, operatorId, 
partition.index, storeName)
     store = StateStore.get(
-      storeId, keySchema, valueSchema, storeVersion, storeConf, 
confBroadcast.value.value)
+      storeId, keySchema, valueSchema, indexOrdinal, storeVersion,
+      storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeUpdateFunction(store, inputIter)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 589042a..228fe86 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -34,17 +34,21 @@ package object state {
         sqlContext: SQLContext,
         checkpointLocation: String,
         operatorId: Long,
+        storeName: String,
         storeVersion: Long,
         keySchema: StructType,
-        valueSchema: StructType)(
+        valueSchema: StructType,
+        indexOrdinal: Option[Int])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): 
StateStoreRDD[T, U] = {
 
       mapPartitionsWithStateStore(
         checkpointLocation,
         operatorId,
+        storeName,
         storeVersion,
         keySchema,
         valueSchema,
+        indexOrdinal,
         sqlContext.sessionState,
         Some(sqlContext.streams.stateStoreCoordinator))(
         storeUpdateFunction)
@@ -54,9 +58,11 @@ package object state {
     private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
         checkpointLocation: String,
         operatorId: Long,
+        storeName: String,
         storeVersion: Long,
         keySchema: StructType,
         valueSchema: StructType,
+        indexOrdinal: Option[Int],
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): 
StateStoreRDD[T, U] = {
@@ -69,14 +75,17 @@ package object state {
         })
         cleanedF(store, iter)
       }
+
       new StateStoreRDD(
         dataRDD,
         wrappedF,
         checkpointLocation,
         operatorId,
+        storeName,
         storeVersion,
         keySchema,
         valueSchema,
+        indexOrdinal,
         sessionState,
         storeCoordinator)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 8dbda29..3e57f3f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -17,21 +17,22 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.concurrent.TimeUnit._
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import 
org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, 
Predicate}
-import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, 
LogicalGroupState, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, 
Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.streaming.state._
-import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types._
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.{CompletionIterator, NextIterator}
 
 
 /** Used to identify the state store for a given operator. */
@@ -61,11 +62,24 @@ trait StateStoreReader extends StatefulOperator {
 }
 
 /** An operator that writes to a StateStore. */
-trait StateStoreWriter extends StatefulOperator {
+trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
+
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"),
     "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
total state rows"),
-    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
updated state rows"))
+    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
updated state rows"),
+    "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total 
time to update rows"),
+    "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total 
time to remove rows"),
+    "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
commit changes")
+  )
+
+  /** Returns the time taken to run `body`, in milliseconds (never negative). */
+  protected def timeTakenMs(body: => Unit): Long = {
+    val startTime = System.nanoTime()
+    val result = body
+    val endTime = System.nanoTime()
+    math.max(NANOSECONDS.toMillis(endTime - startTime), 0)
+  }
 }
 
 /** An operator that supports watermark. */
@@ -108,6 +122,16 @@ trait WatermarkSupport extends UnaryExecNode {
   /** Predicate based on the child output that matches data older than the 
watermark. */
   lazy val watermarkPredicateForData: Option[Predicate] =
     watermarkExpression.map(newPredicate(_, child.output))
+
+  protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
+    if (watermarkPredicateForKeys.nonEmpty) {
+      store.getRange(None, None).foreach { rowPair =>
+        if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
+          store.remove(rowPair.key)
+        }
+      }
+    }
+  }
 }
 
 /**
@@ -126,9 +150,11 @@ case class StateStoreRestoreExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
       operatorId = getStateId.operatorId,
+      storeName = "default",
       storeVersion = getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
         val getKey = GenerateUnsafeProjection.generate(keyExpressions, 
child.output)
@@ -136,7 +162,7 @@ case class StateStoreRestoreExec(
           val key = getKey(row)
           val savedState = store.get(key)
           numOutputRows += 1
-          row +: savedState.toSeq
+          row +: Option(savedState).toSeq
         }
     }
   }
@@ -165,54 +191,88 @@ case class StateStoreSaveExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
       getStateId.operatorId,
+      storeName = "default",
       getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
         val getKey = GenerateUnsafeProjection.generate(keyExpressions, 
child.output)
         val numOutputRows = longMetric("numOutputRows")
         val numTotalStateRows = longMetric("numTotalStateRows")
         val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+        val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+        val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
+        val commitTimeMs = longMetric("commitTimeMs")
 
         outputMode match {
           // Update and output all rows in the StateStore.
           case Some(Complete) =>
-            while (iter.hasNext) {
-              val row = iter.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key.copy(), row.copy())
-              numUpdatedStateRows += 1
+            allUpdatesTimeMs += timeTakenMs {
+              while (iter.hasNext) {
+                val row = iter.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numUpdatedStateRows += 1
+              }
+            }
+            allRemovalsTimeMs += 0
+            commitTimeMs += timeTakenMs {
+              store.commit()
             }
-            store.commit()
             numTotalStateRows += store.numKeys()
-            store.iterator().map { case (k, v) =>
+            store.iterator().map { rowPair =>
               numOutputRows += 1
-              v.asInstanceOf[InternalRow]
+              rowPair.value
             }
 
           // Update and output only rows being evicted from the StateStore
+          // Assumption: watermark predicates must be non-empty if append mode 
is allowed
           case Some(Append) =>
-            while (iter.hasNext) {
-              val row = iter.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key.copy(), row.copy())
-              numUpdatedStateRows += 1
+            allUpdatesTimeMs += timeTakenMs {
+              val filteredIter = iter.filter(row => 
!watermarkPredicateForData.get.eval(row))
+              while (filteredIter.hasNext) {
+                val row = filteredIter.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numUpdatedStateRows += 1
+              }
             }
 
-            // Assumption: Append mode can be done only when watermark has 
been specified
-            store.remove(watermarkPredicateForKeys.get.eval _)
-            store.commit()
+            val removalStartTimeNs = System.nanoTime
+            val rangeIter = store.getRange(None, None)
+
+            new NextIterator[InternalRow] {
+              override protected def getNext(): InternalRow = {
+                var removedValueRow: InternalRow = null
+                while(rangeIter.hasNext && removedValueRow == null) {
+                  val rowPair = rangeIter.next()
+                  if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
+                    store.remove(rowPair.key)
+                    removedValueRow = rowPair.value
+                  }
+                }
+                if (removedValueRow == null) {
+                  finished = true
+                  null
+                } else {
+                  removedValueRow
+                }
+              }
 
-            numTotalStateRows += store.numKeys()
-            store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed 
=>
-              numOutputRows += 1
-              removed.value.asInstanceOf[InternalRow]
+              override protected def close(): Unit = {
+                allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - 
removalStartTimeNs)
+                commitTimeMs += timeTakenMs { store.commit() }
+                numTotalStateRows += store.numKeys()
+              }
             }
 
           // Update and output modified rows from the StateStore.
           case Some(Update) =>
 
+            val updatesStartTimeNs = System.nanoTime
+
             new Iterator[InternalRow] {
 
               // Filter late data using watermark if specified
@@ -223,11 +283,11 @@ case class StateStoreSaveExec(
 
               override def hasNext: Boolean = {
                 if (!baseIterator.hasNext) {
+                  allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - 
updatesStartTimeNs)
+
                   // Remove old aggregates if watermark specified
-                  if (watermarkPredicateForKeys.nonEmpty) {
-                    store.remove(watermarkPredicateForKeys.get.eval _)
-                  }
-                  store.commit()
+                  allRemovalsTimeMs += timeTakenMs { 
removeKeysOlderThanWatermark(store) }
+                  commitTimeMs += timeTakenMs { store.commit() }
                   numTotalStateRows += store.numKeys()
                   false
                 } else {
@@ -238,7 +298,7 @@ case class StateStoreSaveExec(
               override def next(): InternalRow = {
                 val row = baseIterator.next().asInstanceOf[UnsafeRow]
                 val key = getKey(row)
-                store.put(key.copy(), row.copy())
+                store.put(key, row)
                 numOutputRows += 1
                 numUpdatedStateRows += 1
                 row
@@ -273,27 +333,34 @@ case class StreamingDeduplicateExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
       getStateId.operatorId,
+      storeName = "default",
       getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
       val getKey = GenerateUnsafeProjection.generate(keyExpressions, 
child.output)
       val numOutputRows = longMetric("numOutputRows")
       val numTotalStateRows = longMetric("numTotalStateRows")
       val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+      val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+      val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
+      val commitTimeMs = longMetric("commitTimeMs")
 
       val baseIterator = watermarkPredicateForData match {
         case Some(predicate) => iter.filter(row => !predicate.eval(row))
         case None => iter
       }
 
+      val updatesStartTimeNs = System.nanoTime
+
       val result = baseIterator.filter { r =>
         val row = r.asInstanceOf[UnsafeRow]
         val key = getKey(row)
         val value = store.get(key)
-        if (value.isEmpty) {
-          store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW)
+        if (value == null) {
+          store.put(key, StreamingDeduplicateExec.EMPTY_ROW)
           numUpdatedStateRows += 1
           numOutputRows += 1
           true
@@ -304,8 +371,9 @@ case class StreamingDeduplicateExec(
       }
 
       CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
-        watermarkPredicateForKeys.foreach(f => store.remove(f.eval _))
-        store.commit()
+        allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - 
updatesStartTimeNs)
+        allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) 
}
+        commitTimeMs += timeTakenMs { store.commit() }
         numTotalStateRows += store.numKeys()
       })
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index bd197be..4a1a089 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -38,13 +38,13 @@ import org.apache.spark.util.{CompletionIterator, Utils}
 
 class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with 
BeforeAndAfterAll {
 
+  import StateStoreTestsHelper._
+
   private val sparkConf = new 
SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
-  private var tempDir = 
Files.createTempDirectory("StateStoreRDDSuite").toString
+  private val tempDir = 
Files.createTempDirectory("StateStoreRDDSuite").toString
   private val keySchema = StructType(Seq(StructField("key", StringType, true)))
   private val valueSchema = StructType(Seq(StructField("value", IntegerType, 
true)))
 
-  import StateStoreSuite._
-
   after {
     StateStore.stop()
   }
@@ -60,13 +60,14 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
       val opId = 0
       val rdd1 =
         makeRDD(spark.sparkContext, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
-            spark.sqlContext, path, opId, storeVersion = 0, keySchema, 
valueSchema)(
+            spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(
             increment)
       assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
       // Generate next version of stores
       val rdd2 = makeRDD(spark.sparkContext, Seq("a", 
"c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, storeVersion = 1, keySchema, 
valueSchema)(increment)
+        spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, 
valueSchema, None)(
+        increment)
       assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
       // Make sure the previous RDD still has the same data.
@@ -84,7 +85,7 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
         storeVersion: Int): RDD[(String, Int)] = {
       implicit val sqlContext = spark.sqlContext
       makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, storeVersion, keySchema, 
valueSchema)(increment)
+        sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, 
None)(increment)
     }
 
     // Generate RDDs and state store data
@@ -110,7 +111,7 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
       def iteratorOfPuts(store: StateStore, iter: Iterator[String]): 
Iterator[(String, Int)] = {
         val resIterator = iter.map { s =>
           val key = stringToRow(s)
-          val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+          val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0)
           val newValue = oldValue + 1
           store.put(key, intToRow(newValue))
           (s, newValue)
@@ -125,21 +126,24 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
           iter: Iterator[String]): Iterator[(String, Option[Int])] = {
         iter.map { s =>
           val key = stringToRow(s)
-          val value = store.get(key).map(rowToInt)
+          val value = Option(store.get(key)).map(rowToInt)
           (s, value)
         }
       }
 
       val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", 
"c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, storeVersion = 0, keySchema, 
valueSchema)(iteratorOfGets)
+        spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(
+        iteratorOfGets)
       assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" 
-> None))
 
       val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, storeVersion = 0, keySchema, 
valueSchema)(iteratorOfPuts)
+        sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(
+        iteratorOfPuts)
       assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
 
       val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", 
"c")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, storeVersion = 1, keySchema, 
valueSchema)(iteratorOfGets)
+        sqlContext, path, opId, "name", storeVersion = 1, keySchema, 
valueSchema, None)(
+        iteratorOfGets)
       assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> 
Some(1), "c" -> None))
     }
   }
@@ -152,15 +156,16 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
       withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { 
spark =>
         implicit val sqlContext = spark.sqlContext
         val coordinatorRef = sqlContext.streams.stateStoreCoordinator
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), 
"host1", "exec1")
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), 
"host2", "exec2")
+        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, 
"name"), "host1", "exec1")
+        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, 
"name"), "host2", "exec2")
 
         assert(
-          coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
+          coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) ===
             Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
         val rdd = makeRDD(spark.sparkContext, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, storeVersion = 0, keySchema, 
valueSchema)(increment)
+          sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(
+          increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -187,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
         val path = Utils.createDirectory(tempDir, 
Random.nextString(10)).toString
         val opId = 0
         val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, storeVersion = 0, keySchema, 
valueSchema)(increment)
+          sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(increment)
         assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
         // Generate next version of stores
         val rdd2 = makeRDD(spark.sparkContext, Seq("a", 
"c")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, storeVersion = 1, keySchema, 
valueSchema)(increment)
+          sqlContext, path, opId, "name", storeVersion = 1, keySchema, 
valueSchema, None)(increment)
         assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
         // Make sure the previous RDD still has the same data.
@@ -208,7 +213,7 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
   private val increment = (store: StateStore, iter: Iterator[String]) => {
     iter.foreach { s =>
       val key = stringToRow(s)
-      val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+      val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0)
       store.put(key, intToRow(oldValue + 1))
     }
     store.commit()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to