Repository: spark
Updated Branches:
  refs/heads/master b8a743b6a -> fe24634d1


[SPARK-21145][SS] Added StateStoreProviderId with queryRunId to reload 
StateStoreProviders when query is restarted

## What changes were proposed in this pull request?
StateStoreProvider instances are loaded on-demand in an executor when a query is 
started. When a query is restarted, the previously loaded provider instance gets 
reused. However, there is a non-trivial chance that a task of the previous query 
run is still running while the tasks of the restarted run have already started. 
So for a stateful partition, there may be two concurrent tasks related to the 
same stateful partition, and therefore using the same provider instance. This 
can lead to inconsistent results and possibly random failures, as state store 
implementations are not designed to be thread-safe.
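
For context, here is a minimal, hypothetical sketch of the pre-patch executor-wide 
cache (not the actual `StateStore` object; `StateStoreId`, `StateStoreProvider` and 
`StateStore` are assumed to be the types from 
`org.apache.spark.sql.execution.streaming.state`, and the schema/configuration 
arguments of the real `get()` are elided into an `instantiate` function):

```scala
import scala.collection.mutable

// Simplified sketch of the pre-patch executor-wide provider cache.
object PrePatchProviderCacheSketch {
  private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()

  def get(
      storeId: StateStoreId,
      version: Long,
      instantiate: StateStoreId => StateStoreProvider): StateStore = {
    val provider = loadedProviders.synchronized {
      // Keyed only by StateStoreId: a restarted run maps to the same entry as the
      // previous run of the query, so both runs share one provider instance.
      loadedProviders.getOrElseUpdate(storeId, instantiate(storeId))
    }
    // Concurrent tasks from the old and the new run can both call into this
    // (non-thread-safe) provider at the same time.
    provider.getStore(version)
  }
}
```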

To fix this, I have introduced a `StateStoreProviderId` that uniquely identifies 
a provider loaded in an executor. It includes the query run id, ensuring that a 
restarted query forces the executor to load a new provider instance, which 
prevents two concurrent tasks (from two different runs) from reusing the same 
provider instance.
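
The new identifier appears in the diff below; the following is a hypothetical 
sketch of how the executor-side cache is re-keyed with it (schemas and configs 
elided as before, and the object/method names here are illustrative only):

```scala
import java.util.UUID
import scala.collection.mutable

// From the diff: the new identifier. The same StateStoreId under a different
// query run id is a different cache key.
case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID)

// Sketch of the re-keyed cache, mirroring the patched StateStore object.
object ReKeyedProviderCacheSketch {
  private val loadedProviders =
    new mutable.HashMap[StateStoreProviderId, StateStoreProvider]()

  def getProvider(
      id: StateStoreProviderId,
      instantiate: StateStoreId => StateStoreProvider): StateStoreProvider =
    loadedProviders.synchronized {
      // A restarted run carries a fresh queryRunId, so it can never hit the
      // entry cached by the previous run; the executor loads a new provider.
      loadedProviders.getOrElseUpdate(id, instantiate(id.storeId))
    }
}
```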

Additional minor bug fixes:
- All state stores related to a query run are marked as deactivated in the 
`StateStoreCoordinator` so that the executors can unload them and release 
resources.
- Moved the code that determines the checkpoint directory of a state store from 
implementation-specific code (`HDFSBackedStateStoreProvider`) to 
implementation-agnostic code (`StateStoreId`), so that implementations do not 
accidentally get it wrong; see the snippet after this list.
  - Also added the store name to the path, to support multiple stores per SQL 
operator partition.
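
For reference, the checkpoint-path logic now lives on `StateStoreId` itself 
(reproduced from the diff below), so every `StateStoreProvider` implementation 
derives its directory the same way, and the store name becomes part of the path 
for non-default stores:

```scala
import org.apache.hadoop.fs.Path

case class StateStoreId(
    checkpointRootLocation: String,
    operatorId: Long,
    partitionId: Int,
    storeName: String = StateStoreId.DEFAULT_STORE_NAME) {

  /**
   * Checkpoint directory for a single state store, identified uniquely by the
   * tuple (operatorId, partitionId, storeName).
   */
  def storeCheckpointLocation(): Path = {
    if (storeName == StateStoreId.DEFAULT_STORE_NAME) {
      // For reading state data generated before store names were used (Spark <= 2.2)
      new Path(checkpointRootLocation, s"$operatorId/$partitionId")
    } else {
      new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName")
    }
  }
}

object StateStoreId {
  val DEFAULT_STORE_NAME = "default"
}
```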

*Note:* This change does not address the scenario where two tasks of the same 
run (e.g. speculative tasks) are running concurrently in the same executor. The 
chance of this is very small, because ideally speculative tasks should never run 
in the same executor.

## How was this patch tested?
Existing unit tests + new unit test.

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #18355 from tdas/SPARK-21145.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fe24634d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fe24634d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fe24634d

Branch: refs/heads/master
Commit: fe24634d14bc0973ca38222db2f58eafbf0c890d
Parents: b8a743b
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Fri Jun 23 00:43:21 2017 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Fri Jun 23 00:43:21 2017 -0700

----------------------------------------------------------------------
 .../sql/execution/aggregate/AggUtils.scala      |  2 +-
 .../spark/sql/execution/command/commands.scala  |  5 +-
 .../streaming/FlatMapGroupsWithStateExec.scala  |  7 +-
 .../streaming/IncrementalExecution.scala        | 27 +++---
 .../execution/streaming/StreamExecution.scala   |  1 +
 .../state/HDFSBackedStateStoreProvider.scala    | 16 ++--
 .../execution/streaming/state/StateStore.scala  | 91 ++++++++++++++-----
 .../streaming/state/StateStoreCoordinator.scala | 41 +++++----
 .../streaming/state/StateStoreRDD.scala         | 21 +++--
 .../sql/execution/streaming/state/package.scala | 25 ++----
 .../execution/streaming/statefulOperators.scala | 38 ++++----
 .../sql/streaming/StreamingQueryManager.scala   |  1 +
 .../state/StateStoreCoordinatorSuite.scala      | 61 +++++++++++--
 .../streaming/state/StateStoreRDDSuite.scala    | 51 ++++++-----
 .../streaming/state/StateStoreSuite.scala       | 93 +++++++++++++++-----
 .../spark/sql/streaming/StreamSuite.scala       |  2 +-
 .../apache/spark/sql/streaming/StreamTest.scala | 13 ++-
 17 files changed, 329 insertions(+), 166 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index aa789af..12f8cff 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -311,7 +311,7 @@ object AggUtils {
     val saved =
       StateStoreSaveExec(
         groupingAttributes,
-        stateId = None,
+        stateInfo = None,
         outputMode = None,
         eventTimeWatermark = None,
         partialMerged2)

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 2d82fcf..81bc93e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.command
 
+import java.util.UUID
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -117,7 +119,8 @@ case class ExplainCommand(
         // This is used only by explaining `Dataset/DataFrame` created by 
`spark.readStream`, so the
         // output mode does not matter since there is no `Sink`.
         new IncrementalExecution(
-          sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", 0, 
OffsetSeqMetadata(0, 0))
+          sparkSession, logicalPlan, OutputMode.Append(), "<unknown>",
+          UUID.randomUUID, 0, OffsetSeqMetadata(0, 0))
       } else {
         sparkSession.sessionState.executePlan(logicalPlan)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 2aad870..9dcac33 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -50,7 +50,7 @@ case class FlatMapGroupsWithStateExec(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     outputObjAttr: Attribute,
-    stateId: Option[OperatorStateId],
+    stateInfo: Option[StatefulOperatorStateInfo],
     stateEncoder: ExpressionEncoder[Any],
     outputMode: OutputMode,
     timeoutConf: GroupStateTimeout,
@@ -107,10 +107,7 @@ case class FlatMapGroupsWithStateExec(
     }
 
     child.execute().mapPartitionsWithStateStore[InternalRow](
-      getStateId.checkpointLocation,
-      getStateId.operatorId,
-      storeName = "default",
-      getStateId.batchId,
+      getStateInfo,
       groupingAttributes.toStructType,
       stateAttributes.toStructType,
       indexOrdinal = None,

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 622e049..ab89dc6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.internal.Logging
@@ -36,6 +37,7 @@ class IncrementalExecution(
     logicalPlan: LogicalPlan,
     val outputMode: OutputMode,
     val checkpointLocation: String,
+    val runId: UUID,
     val currentBatchId: Long,
     offsetSeqMetadata: OffsetSeqMetadata)
   extends QueryExecution(sparkSession, logicalPlan) with Logging {
@@ -69,7 +71,13 @@ class IncrementalExecution(
    * Records the current id for a given stateful operator in the query plan as 
the `state`
    * preparation walks the query plan.
    */
-  private val operatorId = new AtomicInteger(0)
+  private val statefulOperatorId = new AtomicInteger(0)
+
+  /** Get the state info of the next stateful operator */
+  private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = {
+    StatefulOperatorStateInfo(
+      checkpointLocation, runId, statefulOperatorId.getAndIncrement(), 
currentBatchId)
+  }
 
   /** Locates save/restore pairs surrounding aggregation. */
   val state = new Rule[SparkPlan] {
@@ -78,35 +86,28 @@ class IncrementalExecution(
       case StateStoreSaveExec(keys, None, None, None,
              UnaryExecNode(agg,
                StateStoreRestoreExec(keys2, None, child))) =>
-        val stateId =
-          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), 
currentBatchId)
-
+        val aggStateInfo = nextStatefulOperationStateInfo
         StateStoreSaveExec(
           keys,
-          Some(stateId),
+          Some(aggStateInfo),
           Some(outputMode),
           Some(offsetSeqMetadata.batchWatermarkMs),
           agg.withNewChildren(
             StateStoreRestoreExec(
               keys,
-              Some(stateId),
+              Some(aggStateInfo),
               child) :: Nil))
 
       case StreamingDeduplicateExec(keys, child, None, None) =>
-        val stateId =
-          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), 
currentBatchId)
-
         StreamingDeduplicateExec(
           keys,
           child,
-          Some(stateId),
+          Some(nextStatefulOperationStateInfo),
           Some(offsetSeqMetadata.batchWatermarkMs))
 
       case m: FlatMapGroupsWithStateExec =>
-        val stateId =
-          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), 
currentBatchId)
         m.copy(
-          stateId = Some(stateId),
+          stateInfo = Some(nextStatefulOperationStateInfo),
           batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
           eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs))
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 74f0f50..06bdec8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -652,6 +652,7 @@ class StreamExecution(
         triggerLogicalPlan,
         outputMode,
         checkpointFile("state"),
+        runId,
         currentBatchId,
         offsetSeqMetadata)
       lastExecution.executedPlan // Force the lazy generation of execution plan

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 67d86da..bae7a15 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -92,7 +92,7 @@ private[state] class HDFSBackedStateStoreProvider extends 
StateStoreProvider wit
     @volatile private var state: STATE = UPDATING
     @volatile private var finalDeltaFile: Path = null
 
-    override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
+    override def id: StateStoreId = 
HDFSBackedStateStoreProvider.this.stateStoreId
 
     override def get(key: UnsafeRow): UnsafeRow = {
       mapToUpdate.get(key)
@@ -177,7 +177,7 @@ private[state] class HDFSBackedStateStoreProvider extends 
StateStoreProvider wit
     /**
      * Whether all updates have been committed
      */
-    override private[streaming] def hasCommitted: Boolean = {
+    override def hasCommitted: Boolean = {
       state == COMMITTED
     }
 
@@ -205,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends 
StateStoreProvider wit
       indexOrdinal: Option[Int], // for sorting the data
       storeConf: StateStoreConf,
       hadoopConf: Configuration): Unit = {
-    this.stateStoreId = stateStoreId
+    this.stateStoreId_ = stateStoreId
     this.keySchema = keySchema
     this.valueSchema = valueSchema
     this.storeConf = storeConf
@@ -213,7 +213,7 @@ private[state] class HDFSBackedStateStoreProvider extends 
StateStoreProvider wit
     fs.mkdirs(baseDir)
   }
 
-  override def id: StateStoreId = stateStoreId
+  override def stateStoreId: StateStoreId = stateStoreId_
 
   /** Do maintenance backing data files, including creating snapshots and 
cleaning up old files */
   override def doMaintenance(): Unit = {
@@ -231,20 +231,20 @@ private[state] class HDFSBackedStateStoreProvider extends 
StateStoreProvider wit
   }
 
   override def toString(): String = {
-    s"HDFSStateStoreProvider[id = (op=${id.operatorId}, 
part=${id.partitionId}), dir = $baseDir]"
+    s"HDFSStateStoreProvider[" +
+      s"id = 
(op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = 
$baseDir]"
   }
 
   /* Internal fields and methods */
 
-  @volatile private var stateStoreId: StateStoreId = _
+  @volatile private var stateStoreId_ : StateStoreId = _
   @volatile private var keySchema: StructType = _
   @volatile private var valueSchema: StructType = _
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
 
   private lazy val loadedMaps = new mutable.HashMap[Long, MapType]
-  private lazy val baseDir =
-    new Path(id.checkpointLocation, 
s"${id.operatorId}/${id.partitionId.toString}")
+  private lazy val baseDir = stateStoreId.storeCheckpointLocation()
   private lazy val fs = baseDir.getFileSystem(hadoopConf)
   private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new 
SparkConf)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 29c456f..a94ff8a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.UUID
 import java.util.concurrent.{ScheduledFuture, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
 
@@ -24,14 +25,14 @@ import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkContext, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{ThreadUtils, Utils}
 
-
 /**
  * Base trait for a versioned key-value store. Each instance of a `StateStore` 
represents a specific
  * version of state data, and such instances are created through a 
[[StateStoreProvider]].
@@ -99,7 +100,7 @@ trait StateStore {
   /**
    * Whether all updates have been committed
    */
-  private[streaming] def hasCommitted: Boolean
+  def hasCommitted: Boolean
 }
 
 
@@ -147,7 +148,7 @@ trait StateStoreProvider {
    * Return the id of the StateStores this provider will generate.
    * Should be the same as the one passed in init().
    */
-  def id: StateStoreId
+  def stateStoreId: StateStoreId
 
   /** Called when the provider instance is unloaded from the executor */
   def close(): Unit
@@ -179,13 +180,46 @@ object StateStoreProvider {
   }
 }
 
+/**
+ * Unique identifier for a provider, used to identify when providers can be 
reused.
+ * Note that `queryRunId` is used to uniquely identify a provider, so that the 
same provider
+ * instance is not reused across query restarts.
+ */
+case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID)
 
-/** Unique identifier for a bunch of keyed state data. */
+/**
+ * Unique identifier for a bunch of keyed state data.
+ * @param checkpointRootLocation Root directory where all the state data of a 
query is stored
+ * @param operatorId Unique id of a stateful operator
+ * @param partitionId Index of the partition of an operators state data
+ * @param storeName Optional, name of the store. Each partition can optionally 
use multiple state
+ *                  stores, but they have to be identified by distinct names.
+ */
 case class StateStoreId(
-    checkpointLocation: String,
+    checkpointRootLocation: String,
     operatorId: Long,
     partitionId: Int,
-    name: String = "")
+    storeName: String = StateStoreId.DEFAULT_STORE_NAME) {
+
+  /**
+   * Checkpoint directory to be used by a single state store, identified 
uniquely by the tuple
+   * (operatorId, partitionId, storeName). All implementations of 
[[StateStoreProvider]] should
+   * use this path for saving state data, as this ensures that distinct stores 
will write to
+   * different locations.
+   */
+  def storeCheckpointLocation(): Path = {
+    if (storeName == StateStoreId.DEFAULT_STORE_NAME) {
+      // For reading state store data that was generated before store names 
were used (Spark <= 2.2)
+      new Path(checkpointRootLocation, s"$operatorId/$partitionId")
+    } else {
+      new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName")
+    }
+  }
+}
+
+object StateStoreId {
+  val DEFAULT_STORE_NAME = "default"
+}
 
 /** Mutable, and reusable class for representing a pair of UnsafeRows. */
 class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) {
@@ -211,7 +245,7 @@ object StateStore extends Logging {
   val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
 
   @GuardedBy("loadedProviders")
-  private val loadedProviders = new mutable.HashMap[StateStoreId, 
StateStoreProvider]()
+  private val loadedProviders = new mutable.HashMap[StateStoreProviderId, 
StateStoreProvider]()
 
   /**
    * Runs the `task` periodically and automatically cancels it if there is an 
exception. `onError`
@@ -253,7 +287,7 @@ object StateStore extends Logging {
 
   /** Get or create a store associated with the id. */
   def get(
-      storeId: StateStoreId,
+      storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
       indexOrdinal: Option[Int],
@@ -264,24 +298,24 @@ object StateStore extends Logging {
     val storeProvider = loadedProviders.synchronized {
       startMaintenanceIfNeeded()
       val provider = loadedProviders.getOrElseUpdate(
-        storeId,
+        storeProviderId,
         StateStoreProvider.instantiate(
-          storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+          storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, 
storeConf, hadoopConf)
       )
-      reportActiveStoreInstance(storeId)
+      reportActiveStoreInstance(storeProviderId)
       provider
     }
     storeProvider.getStore(version)
   }
 
   /** Unload a state store provider */
-  def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
-    loadedProviders.remove(storeId).foreach(_.close())
+  def unload(storeProviderId: StateStoreProviderId): Unit = 
loadedProviders.synchronized {
+    loadedProviders.remove(storeProviderId).foreach(_.close())
   }
 
   /** Whether a state store provider is loaded or not */
-  def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized {
-    loadedProviders.contains(storeId)
+  def isLoaded(storeProviderId: StateStoreProviderId): Boolean = 
loadedProviders.synchronized {
+    loadedProviders.contains(storeProviderId)
   }
 
   def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
@@ -340,21 +374,21 @@ object StateStore extends Logging {
     }
   }
 
-  private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
+  private def reportActiveStoreInstance(storeProviderId: 
StateStoreProviderId): Unit = {
     if (SparkEnv.get != null) {
       val host = SparkEnv.get.blockManager.blockManagerId.host
       val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
-      coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
-      logDebug(s"Reported that the loaded instance $storeId is active")
+      coordinatorRef.foreach(_.reportActiveInstance(storeProviderId, host, 
executorId))
+      logInfo(s"Reported that the loaded instance $storeProviderId is active")
     }
   }
 
-  private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
+  private def verifyIfStoreInstanceActive(storeProviderId: 
StateStoreProviderId): Boolean = {
     if (SparkEnv.get != null) {
       val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
       val verified =
-        coordinatorRef.map(_.verifyIfInstanceActive(storeId, 
executorId)).getOrElse(false)
-      logDebug(s"Verified whether the loaded instance $storeId is active: 
$verified")
+        coordinatorRef.map(_.verifyIfInstanceActive(storeProviderId, 
executorId)).getOrElse(false)
+      logDebug(s"Verified whether the loaded instance $storeProviderId is 
active: $verified")
       verified
     } else {
       false
@@ -364,12 +398,21 @@ object StateStore extends Logging {
   private def coordinatorRef: Option[StateStoreCoordinatorRef] = 
loadedProviders.synchronized {
     val env = SparkEnv.get
     if (env != null) {
-      if (_coordRef == null) {
+      logInfo("Env is not null")
+      val isDriver =
+        env.executorId == SparkContext.DRIVER_IDENTIFIER ||
+          env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER
+      // If running locally, then the coordinator reference in _coordRef may 
have become inactive
+      // as SparkContext + SparkEnv may have been restarted. Hence, when 
running in driver,
+      // always recreate the reference.
+      if (isDriver || _coordRef == null) {
+        logInfo("Getting StateStoreCoordinatorRef")
         _coordRef = StateStoreCoordinatorRef.forExecutor(env)
       }
-      logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
+      logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
       Some(_coordRef)
     } else {
+      logInfo("Env is null")
       _coordRef = null
       None
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
index d0f8188..3884f5e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.UUID
+
 import scala.collection.mutable
 
 import org.apache.spark.SparkEnv
@@ -29,16 +31,19 @@ import org.apache.spark.util.RpcUtils
 private sealed trait StateStoreCoordinatorMessage extends Serializable
 
 /** Classes representing messages */
-private case class ReportActiveInstance(storeId: StateStoreId, host: String, 
executorId: String)
+private case class ReportActiveInstance(
+    storeId: StateStoreProviderId,
+    host: String,
+    executorId: String)
   extends StateStoreCoordinatorMessage
 
-private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: 
String)
+private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, 
executorId: String)
   extends StateStoreCoordinatorMessage
 
-private case class GetLocation(storeId: StateStoreId)
+private case class GetLocation(storeId: StateStoreProviderId)
   extends StateStoreCoordinatorMessage
 
-private case class DeactivateInstances(checkpointLocation: String)
+private case class DeactivateInstances(runId: UUID)
   extends StateStoreCoordinatorMessage
 
 private object StopCoordinator
@@ -80,25 +85,27 @@ object StateStoreCoordinatorRef extends Logging {
 class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
 
   private[state] def reportActiveInstance(
-      storeId: StateStoreId,
+      stateStoreProviderId: StateStoreProviderId,
       host: String,
       executorId: String): Unit = {
-    rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId))
+    rpcEndpointRef.send(ReportActiveInstance(stateStoreProviderId, host, 
executorId))
   }
 
   /** Verify whether the given executor has the active instance of a state 
store */
-  private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: 
String): Boolean = {
-    rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(storeId, 
executorId))
+  private[state] def verifyIfInstanceActive(
+      stateStoreProviderId: StateStoreProviderId,
+      executorId: String): Boolean = {
+    
rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(stateStoreProviderId, 
executorId))
   }
 
   /** Get the location of the state store */
-  private[state] def getLocation(storeId: StateStoreId): Option[String] = {
-    rpcEndpointRef.askSync[Option[String]](GetLocation(storeId))
+  private[state] def getLocation(stateStoreProviderId: StateStoreProviderId): 
Option[String] = {
+    rpcEndpointRef.askSync[Option[String]](GetLocation(stateStoreProviderId))
   }
 
-  /** Deactivate instances related to a set of operator */
-  private[state] def deactivateInstances(storeRootLocation: String): Unit = {
-    rpcEndpointRef.askSync[Boolean](DeactivateInstances(storeRootLocation))
+  /** Deactivate instances related to a query */
+  private[sql] def deactivateInstances(runId: UUID): Unit = {
+    rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId))
   }
 
   private[state] def stop(): Unit = {
@@ -113,7 +120,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: 
RpcEndpointRef) {
  */
 private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
     extends ThreadSafeRpcEndpoint with Logging {
-  private val instances = new mutable.HashMap[StateStoreId, 
ExecutorCacheTaskLocation]
+  private val instances = new mutable.HashMap[StateStoreProviderId, 
ExecutorCacheTaskLocation]
 
   override def receive: PartialFunction[Any, Unit] = {
     case ReportActiveInstance(id, host, executorId) =>
@@ -135,11 +142,11 @@ private class StateStoreCoordinator(override val rpcEnv: 
RpcEnv)
       logDebug(s"Got location of the state store $id: $executorId")
       context.reply(executorId)
 
-    case DeactivateInstances(checkpointLocation) =>
+    case DeactivateInstances(runId) =>
       val storeIdsToRemove =
-        instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq
+        instances.keys.filter(_.queryRunId == runId).toSeq
       instances --= storeIdsToRemove
-      logDebug(s"Deactivating instances related to checkpoint location 
$checkpointLocation: " +
+      logDebug(s"Deactivating instances related to checkpoint location $runId: 
" +
         storeIdsToRemove.mkString(", "))
       context.reply(true)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index b744c25..01d8e75 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.UUID
+
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Partition, TaskContext}
@@ -34,8 +36,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     dataRDD: RDD[T],
     storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
     checkpointLocation: String,
+    queryRunId: UUID,
     operatorId: Long,
-    storeName: String,
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
@@ -52,16 +54,25 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
 
   override protected def getPartitions: Array[Partition] = dataRDD.partitions
 
+  /**
+   * Set the preferred location of each partition using the executor that has 
the related
+   * [[StateStoreProvider]] already loaded.
+   */
   override def getPreferredLocations(partition: Partition): Seq[String] = {
-    val storeId = StateStoreId(checkpointLocation, operatorId, 
partition.index, storeName)
-    storeCoordinator.flatMap(_.getLocation(storeId)).toSeq
+    val stateStoreProviderId = StateStoreProviderId(
+      StateStoreId(checkpointLocation, operatorId, partition.index),
+      queryRunId)
+    storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq
   }
 
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = 
{
     var store: StateStore = null
-    val storeId = StateStoreId(checkpointLocation, operatorId, 
partition.index, storeName)
+    val storeProviderId = StateStoreProviderId(
+      StateStoreId(checkpointLocation, operatorId, partition.index),
+      queryRunId)
+
     store = StateStore.get(
-      storeId, keySchema, valueSchema, indexOrdinal, storeVersion,
+      storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion,
       storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeUpdateFunction(store, inputIter)

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 228fe86..a0086e2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
+
 import scala.reflect.ClassTag
 
 import org.apache.spark.TaskContext
@@ -32,20 +34,14 @@ package object state {
     /** Map each partition of an RDD along with data in a [[StateStore]]. */
     def mapPartitionsWithStateStore[U: ClassTag](
         sqlContext: SQLContext,
-        checkpointLocation: String,
-        operatorId: Long,
-        storeName: String,
-        storeVersion: Long,
+        stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
         indexOrdinal: Option[Int])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): 
StateStoreRDD[T, U] = {
 
       mapPartitionsWithStateStore(
-        checkpointLocation,
-        operatorId,
-        storeName,
-        storeVersion,
+        stateInfo,
         keySchema,
         valueSchema,
         indexOrdinal,
@@ -56,10 +52,7 @@ package object state {
 
     /** Map each partition of an RDD along with data in a [[StateStore]]. */
     private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
-        checkpointLocation: String,
-        operatorId: Long,
-        storeName: String,
-        storeVersion: Long,
+        stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
         indexOrdinal: Option[Int],
@@ -79,10 +72,10 @@ package object state {
       new StateStoreRDD(
         dataRDD,
         wrappedF,
-        checkpointLocation,
-        operatorId,
-        storeName,
-        storeVersion,
+        stateInfo.checkpointLocation,
+        stateInfo.queryRunId,
+        stateInfo.operatorId,
+        stateInfo.storeVersion,
         keySchema,
         valueSchema,
         indexOrdinal,

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 3e57f3f..c572246 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
 import java.util.concurrent.TimeUnit._
 
 import org.apache.spark.rdd.RDD
@@ -36,20 +37,22 @@ import org.apache.spark.util.{CompletionIterator, 
NextIterator}
 
 
 /** Used to identify the state store for a given operator. */
-case class OperatorStateId(
+case class StatefulOperatorStateInfo(
     checkpointLocation: String,
+    queryRunId: UUID,
     operatorId: Long,
-    batchId: Long)
+    storeVersion: Long)
 
 /**
- * An operator that reads or writes state from the [[StateStore]].  The 
[[OperatorStateId]] should
- * be filled in by `prepareForExecution` in [[IncrementalExecution]].
+ * An operator that reads or writes state from the [[StateStore]].
+ * The [[StatefulOperatorStateInfo]] should be filled in by 
`prepareForExecution` in
+ * [[IncrementalExecution]].
  */
 trait StatefulOperator extends SparkPlan {
-  def stateId: Option[OperatorStateId]
+  def stateInfo: Option[StatefulOperatorStateInfo]
 
-  protected def getStateId: OperatorStateId = attachTree(this) {
-    stateId.getOrElse {
+  protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) {
+    stateInfo.getOrElse {
       throw new IllegalStateException("State location not present for 
execution")
     }
   }
@@ -140,7 +143,7 @@ trait WatermarkSupport extends UnaryExecNode {
  */
 case class StateStoreRestoreExec(
     keyExpressions: Seq[Attribute],
-    stateId: Option[OperatorStateId],
+    stateInfo: Option[StatefulOperatorStateInfo],
     child: SparkPlan)
   extends UnaryExecNode with StateStoreReader {
 
@@ -148,10 +151,7 @@ case class StateStoreRestoreExec(
     val numOutputRows = longMetric("numOutputRows")
 
     child.execute().mapPartitionsWithStateStore(
-      getStateId.checkpointLocation,
-      operatorId = getStateId.operatorId,
-      storeName = "default",
-      storeVersion = getStateId.batchId,
+      getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
       indexOrdinal = None,
@@ -177,7 +177,7 @@ case class StateStoreRestoreExec(
  */
 case class StateStoreSaveExec(
     keyExpressions: Seq[Attribute],
-    stateId: Option[OperatorStateId] = None,
+    stateInfo: Option[StatefulOperatorStateInfo] = None,
     outputMode: Option[OutputMode] = None,
     eventTimeWatermark: Option[Long] = None,
     child: SparkPlan)
@@ -189,10 +189,7 @@ case class StateStoreSaveExec(
       "Incorrect planning in IncrementalExecution, outputMode has not been 
set")
 
     child.execute().mapPartitionsWithStateStore(
-      getStateId.checkpointLocation,
-      getStateId.operatorId,
-      storeName = "default",
-      getStateId.batchId,
+      getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
       indexOrdinal = None,
@@ -319,7 +316,7 @@ case class StateStoreSaveExec(
 case class StreamingDeduplicateExec(
     keyExpressions: Seq[Attribute],
     child: SparkPlan,
-    stateId: Option[OperatorStateId] = None,
+    stateInfo: Option[StatefulOperatorStateInfo] = None,
     eventTimeWatermark: Option[Long] = None)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
 
@@ -331,10 +328,7 @@ case class StreamingDeduplicateExec(
     metrics // force lazy init at driver
 
     child.execute().mapPartitionsWithStateStore(
-      getStateId.checkpointLocation,
-      getStateId.operatorId,
-      storeName = "default",
-      getStateId.batchId,
+      getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
       indexOrdinal = None,

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 002c454..48b0ea2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -332,5 +332,6 @@ class StreamingQueryManager private[sql] (sparkSession: 
SparkSession) extends Lo
       }
       awaitTerminationLock.notifyAll()
     }
+    stateStoreCoordinator.deactivateInstances(terminatedQuery.runId)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
index a7e3262..9a7595e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
@@ -17,11 +17,17 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.UUID
+
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.streaming.{MemoryStream, 
StreamingQueryWrapper}
+import org.apache.spark.sql.functions.count
+import org.apache.spark.util.Utils
 
 class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext 
{
 
@@ -29,7 +35,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with 
SharedSparkContext {
 
   test("report, verify, getLocation") {
     withCoordinatorRef(sc) { coordinatorRef =>
-      val id = StateStoreId("x", 0, 0)
+      val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID)
 
       assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false)
       assert(coordinatorRef.getLocation(id) === None)
@@ -57,9 +63,11 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with 
SharedSparkContext {
 
   test("make inactive") {
     withCoordinatorRef(sc) { coordinatorRef =>
-      val id1 = StateStoreId("x", 0, 0)
-      val id2 = StateStoreId("y", 1, 0)
-      val id3 = StateStoreId("x", 0, 1)
+      val runId1 = UUID.randomUUID
+      val runId2 = UUID.randomUUID
+      val id1 = StateStoreProviderId(StateStoreId("x", 0, 0), runId1)
+      val id2 = StateStoreProviderId(StateStoreId("y", 1, 0), runId2)
+      val id3 = StateStoreProviderId(StateStoreId("x", 0, 1), runId1)
       val host = "hostX"
       val exec = "exec1"
 
@@ -73,7 +81,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with 
SharedSparkContext {
         assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true)
       }
 
-      coordinatorRef.deactivateInstances("x")
+      coordinatorRef.deactivateInstances(runId1)
 
       assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false)
       assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true)
@@ -85,7 +93,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with 
SharedSparkContext {
           Some(ExecutorCacheTaskLocation(host, exec).toString))
       assert(coordinatorRef.getLocation(id3) === None)
 
-      coordinatorRef.deactivateInstances("y")
+      coordinatorRef.deactivateInstances(runId2)
       assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false)
       assert(coordinatorRef.getLocation(id2) === None)
     }
@@ -95,7 +103,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with 
SharedSparkContext {
     withCoordinatorRef(sc) { coordRef1 =>
       val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env)
 
-      val id = StateStoreId("x", 0, 0)
+      val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID)
 
       coordRef1.reportActiveInstance(id, "hostX", "exec1")
 
@@ -107,6 +115,45 @@ class StateStoreCoordinatorSuite extends SparkFunSuite 
with SharedSparkContext {
       }
     }
   }
+
+  test("query stop deactivates related store providers") {
+    var coordRef: StateStoreCoordinatorRef = null
+    try {
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+      SparkSession.setActiveSession(spark)
+      import spark.implicits._
+      coordRef = spark.streams.stateStoreCoordinator
+      implicit val sqlContext = spark.sqlContext
+      spark.conf.set("spark.sql.shuffle.partitions", "1")
+
+      // Start a query and run a batch to load state stores
+      val inputData = MemoryStream[Int]
+      val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // 
stateful query
+      val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+      val query = aggregated.writeStream
+        .format("memory")
+        .outputMode("update")
+        .queryName("query")
+        .option("checkpointLocation", checkpointLocation.toString)
+        .start()
+      inputData.addData(1, 2, 3)
+      query.processAllAvailable()
+
+      // Verify state store has been loaded
+      val stateCheckpointDir =
+        
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation
+      val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 
0, 0), query.runId)
+      assert(coordRef.getLocation(providerId).nonEmpty)
+
+      // Stop and verify whether the stores are deactivated in the coordinator
+      query.stop()
+      assert(coordRef.getLocation(providerId).isEmpty)
+    } finally {
+      SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop()))
+      if (coordRef != null) coordRef.stop()
+      StateStore.stop()
+    }
+  }
 }
 
 object StateStoreCoordinatorSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 4a1a089..defb9ed 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -19,20 +19,19 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
 import java.nio.file.Files
+import java.util.UUID
 
 import scala.util.Random
 
 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-import org.scalatest.concurrent.Eventually._
-import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.sql.LocalSparkSession._
-import org.apache.spark.LocalSparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.LocalSparkSession._
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, 
StructType}
 import org.apache.spark.util.{CompletionIterator, Utils}
 
@@ -57,16 +56,14 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
   test("versioning and immutability") {
     withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { 
spark =>
       val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
-      val opId = 0
-      val rdd1 =
-        makeRDD(spark.sparkContext, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
-            spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(
+      val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
+            spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, 
valueSchema, None)(
             increment)
       assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
       // Generate next version of stores
       val rdd2 = makeRDD(spark.sparkContext, Seq("a", 
"c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, 
valueSchema, None)(
+        spark.sqlContext, operatorStateInfo(path, version = 1), keySchema, 
valueSchema, None)(
         increment)
       assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
@@ -76,7 +73,6 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
   }
 
   test("recovering from files") {
-    val opId = 0
     val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
 
     def makeStoreRDD(
@@ -85,7 +81,8 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
         storeVersion: Int): RDD[(String, Int)] = {
       implicit val sqlContext = spark.sqlContext
       makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, 
None)(increment)
+        sqlContext, operatorStateInfo(path, version = storeVersion),
+        keySchema, valueSchema, None)(increment)
     }
 
     // Generate RDDs and state store data
@@ -132,17 +129,17 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
       }
 
       val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", 
"c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(
+        spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, 
valueSchema, None)(
         iteratorOfGets)
       assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" 
-> None))
 
       val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(
+        sqlContext, operatorStateInfo(path, version = 0), keySchema, 
valueSchema, None)(
         iteratorOfPuts)
       assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
 
       val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", 
"c")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, "name", storeVersion = 1, keySchema, 
valueSchema, None)(
+        sqlContext, operatorStateInfo(path, version = 1), keySchema, 
valueSchema, None)(
         iteratorOfGets)
       assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> 
Some(1), "c" -> None))
     }
@@ -150,22 +147,25 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
 
   test("preferred locations using StateStoreCoordinator") {
     quietly {
+      val queryRunId = UUID.randomUUID
       val opId = 0
       val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
 
       withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { 
spark =>
         implicit val sqlContext = spark.sqlContext
         val coordinatorRef = sqlContext.streams.stateStoreCoordinator
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, 
"name"), "host1", "exec1")
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, 
"name"), "host2", "exec2")
+        val storeProviderId1 = StateStoreProviderId(StateStoreId(path, opId, 
0), queryRunId)
+        val storeProviderId2 = StateStoreProviderId(StateStoreId(path, opId, 
1), queryRunId)
+        coordinatorRef.reportActiveInstance(storeProviderId1, "host1", "exec1")
+        coordinatorRef.reportActiveInstance(storeProviderId2, "host2", "exec2")
 
-        assert(
-          coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) ===
+        require(
+          coordinatorRef.getLocation(storeProviderId1) ===
             Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
         val rdd = makeRDD(spark.sparkContext, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(
-          increment)
+          sqlContext, operatorStateInfo(path, queryRunId = queryRunId),
+          keySchema, valueSchema, None)(increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -192,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
         val path = Utils.createDirectory(tempDir, 
Random.nextString(10)).toString
         val opId = 0
         val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", 
"a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, "name", storeVersion = 0, keySchema, 
valueSchema, None)(increment)
+          sqlContext, operatorStateInfo(path, version = 0), keySchema, 
valueSchema, None)(increment)
         assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
         // Generate next version of stores
         val rdd2 = makeRDD(spark.sparkContext, Seq("a", 
"c")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, "name", storeVersion = 1, keySchema, 
valueSchema, None)(increment)
+          sqlContext, operatorStateInfo(path, version = 1), keySchema, 
valueSchema, None)(increment)
         assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
         // Make sure the previous RDD still has the same data.
@@ -210,6 +210,13 @@ class StateStoreRDDSuite extends SparkFunSuite with 
BeforeAndAfter with BeforeAn
     sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2)
   }
 
+  private def operatorStateInfo(
+      path: String,
+      queryRunId: UUID = UUID.randomUUID,
+      version: Int = 0): StatefulOperatorStateInfo = {
+    StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version)
+  }
+
   private val increment = (store: StateStore, iter: Iterator[String]) => {
     iter.foreach { s =>
       val key = stringToRow(s)

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index af2b9f1..c2087ec 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io.{File, IOException}
 import java.net.URI
+import java.util.UUID
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -33,8 +34,11 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite}
 import org.apache.spark.LocalSparkContext._
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.execution.streaming.{MemoryStream, 
StreamingQueryWrapper}
+import org.apache.spark.sql.functions.count
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -143,7 +147,7 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     provider.getStore(0).commit()
 
     // Verify we don't leak temp files
-    val tempFiles = FileUtils.listFiles(new 
File(provider.id.checkpointLocation),
+    val tempFiles = FileUtils.listFiles(new 
File(provider.stateStoreId.checkpointRootLocation),
       null, true).asScala.filter(_.getName.startsWith("temp-"))
     assert(tempFiles.isEmpty)
   }
@@ -183,7 +187,7 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   test("StateStore.get") {
     quietly {
       val dir = newDir()
-      val storeId = StateStoreId(dir, 0, 0)
+      val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), 
UUID.randomUUID)
       val storeConf = StateStoreConf.empty
       val hadoopConf = new Configuration()
 
@@ -243,18 +247,18 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       .set("spark.rpc.numRetries", "1")
     val opId = 0
     val dir = newDir()
-    val storeId = StateStoreId(dir, opId, 0)
+    val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), 
UUID.randomUUID)
     val sqlConf = new SQLConf()
     sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
     val storeConf = StateStoreConf(sqlConf)
     val hadoopConf = new Configuration()
-    val provider = newStoreProvider(storeId)
+    val provider = newStoreProvider(storeProviderId.storeId)
 
     var latestStoreVersion = 0
 
     def generateStoreVersions() {
       for (i <- 1 to 20) {
-        val store = StateStore.get(storeId, keySchema, valueSchema, None,
+        val store = StateStore.get(storeProviderId, keySchema, valueSchema, 
None,
           latestStoreVersion, storeConf, hadoopConf)
         put(store, "a", i)
         store.commit()
@@ -274,7 +278,8 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
           eventually(timeout(timeoutDuration)) {
             // Store should have been reported to the coordinator
-            assert(coordinatorRef.getLocation(storeId).nonEmpty, "active 
instance was not reported")
+            assert(coordinatorRef.getLocation(storeProviderId).nonEmpty,
+              "active instance was not reported")
 
             // Background maintenance should clean up and generate snapshots
             assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
@@ -295,35 +300,35 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
             assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
           }
 
-          // If driver decides to deactivate all instances of the store, then this instance
-          // should be unloaded
-          coordinatorRef.deactivateInstances(dir)
+          // If driver decides to deactivate all stores related to a query run,
+          // then this instance should be unloaded
+          coordinatorRef.deactivateInstances(storeProviderId.queryRunId)
           eventually(timeout(timeoutDuration)) {
-            assert(!StateStore.isLoaded(storeId))
+            assert(!StateStore.isLoaded(storeProviderId))
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+          StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None,
             latestStoreVersion, storeConf, hadoopConf)
-          assert(StateStore.isLoaded(storeId))
+          assert(StateStore.isLoaded(storeProviderId))
 
           // If some other executor loads the store, then this instance should be unloaded
-          coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec")
+          coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec")
           eventually(timeout(timeoutDuration)) {
-            assert(!StateStore.isLoaded(storeId))
+            assert(!StateStore.isLoaded(storeProviderId))
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+          StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None,
             latestStoreVersion, storeConf, hadoopConf)
-          assert(StateStore.isLoaded(storeId))
+          assert(StateStore.isLoaded(storeProviderId))
         }
       }
 
       // Verify if instance is unloaded if SparkContext is stopped
       eventually(timeout(timeoutDuration)) {
         require(SparkEnv.get === null)
-        assert(!StateStore.isLoaded(storeId))
+        assert(!StateStore.isLoaded(storeProviderId))
         assert(!StateStore.isMaintenanceRunning)
       }
     }
@@ -344,7 +349,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
   test("SPARK-18416: do not create temp delta file until the store is updated") {
     val dir = newDir()
-    val storeId = StateStoreId(dir, 0, 0)
+    val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID)
     val storeConf = StateStoreConf.empty
     val hadoopConf = new Configuration()
     val deltaFileDir = new File(s"$dir/0/0/")
@@ -408,12 +413,60 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     assert(numDeltaFiles === 3)
   }
 
+  test("SPARK-21145: Restarted queries create new provider instances") {
+    try {
+      val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+      val spark = SparkSession.builder().master("local[2]").getOrCreate()
+      SparkSession.setActiveSession(spark)
+      implicit val sqlContext = spark.sqlContext
+      spark.conf.set("spark.sql.shuffle.partitions", "1")
+      import spark.implicits._
+      val inputData = MemoryStream[Int]
+
+      def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = {
+        val aggregated = inputData.toDF().groupBy("value").agg(count("*"))
+        // stateful query
+        val query = aggregated.writeStream
+          .format("memory")
+          .outputMode("complete")
+          .queryName("query")
+          .option("checkpointLocation", checkpointLocation.toString)
+          .start()
+        inputData.addData(1, 2, 3)
+        query.processAllAvailable()
+        require(query.lastProgress != null) // at least one batch processed after start
+        val loadedProvidersMethod =
+          PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]('loadedProviders)
+        val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod()
+        val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq }
+        query.stop()
+        loadedProviders
+      }
+
+      val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders()
+      require(loadedProvidersAfterRun1.length === 1)
+
+      val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders()
+      assert(loadedProvidersAfterRun2.length === 2)   // two providers loaded for 2 runs
+
+      // Both providers should have the same StateStoreId, but they should be different objects
+      assert(loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId)
+      assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1))
+
+    } finally {
+      SparkSession.getActiveSession.foreach { spark =>
+        spark.streams.active.foreach(_.stop())
+        spark.stop()
+      }
+    }
+  }
+
   override def newStoreProvider(): HDFSBackedStateStoreProvider = {
     newStoreProvider(opId = Random.nextInt(), partition = 0)
   }
 
   override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = {
-    newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation)
+    newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointRootLocation)
   }
 
   override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = {
@@ -423,7 +476,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   override def getData(
     provider: HDFSBackedStateStoreProvider,
     version: Int = -1): Set[(String, Int)] = {
-    val reloadedProvider = newStoreProvider(provider.id)
+    val reloadedProvider = newStoreProvider(provider.stateStoreId)
     if (version < 0) {
       reloadedProvider.latestIterator().map(rowsToStringInt).toSet
     } else {

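For readers skimming the new SPARK-21145 test above: the property it asserts is that the executor-side cache of loaded providers is keyed by StateStoreProviderId, i.e. by (StateStoreId, queryRunId), so a restarted query (new run id) maps the same stateful partition to a fresh provider object instead of reusing the one from the previous run. The standalone Scala sketch below illustrates that keying scheme with simplified stand-ins; the field lists, DummyProvider, ProviderCache, Demo and the "/tmp/checkpoint" path are illustrative and are not Spark's real internals.

import java.util.UUID
import scala.collection.mutable

// Simplified stand-ins that only mirror the shape of the real classes.
case class StateStoreId(checkpointRootLocation: String, operatorId: Long, partitionId: Int)
case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID)

class DummyProvider(val providerId: StateStoreProviderId)

object ProviderCache {
  private val loadedProviders = mutable.HashMap.empty[StateStoreProviderId, DummyProvider]

  // Same provider id => same cached instance; a new queryRunId => a fresh instance.
  def getOrLoad(id: StateStoreProviderId): DummyProvider = loadedProviders.synchronized {
    loadedProviders.getOrElseUpdate(id, new DummyProvider(id))
  }
}

object Demo extends App {
  val storeId = StateStoreId("/tmp/checkpoint", operatorId = 0L, partitionId = 0)
  val run1 = StateStoreProviderId(storeId, UUID.randomUUID)  // first run of the query
  val run2 = StateStoreProviderId(storeId, UUID.randomUUID)  // restarted run

  val p1 = ProviderCache.getOrLoad(run1)
  val p2 = ProviderCache.getOrLoad(run2)

  assert(p1.providerId.storeId == p2.providerId.storeId)  // same stateful partition
  assert(!(p1 eq p2))                                      // but distinct provider instances
}

Because the sketch's id types are case classes, two ids with equal fields hash to the same cache entry, which is why the distinct run UUIDs are what force the second lookup to create a new provider.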
http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 4ede4fd..86c3a35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -777,7 +777,7 @@ class TestStateStoreProvider extends StateStoreProvider {
     throw new Exception("Successfully instantiated")
   }
 
-  override def id: StateStoreId = null
+  override def stateStoreId: StateStoreId = null
 
   override def close(): Unit = { }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 2a4039c..b2c42ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -26,9 +26,8 @@ import scala.reflect.ClassTag
 import scala.util.Random
 import scala.util.control.NonFatal
 
-import org.scalatest.Assertions
+import org.scalatest.{Assertions, BeforeAndAfterAll}
 import org.scalatest.concurrent.{Eventually, Timeouts}
-import org.scalatest.concurrent.Eventually._
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.exceptions.TestFailedDueToTimeoutException
 import org.scalatest.time.Span
@@ -39,9 +38,10 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.StreamingQueryListener._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
+import org.apache.spark.util.{Clock, SystemClock, Utils}
 
 /**
  * A framework for implementing tests for streaming queries and sources.
@@ -67,7 +67,12 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
  * avoid hanging forever in the case of failures. However, individual suites can change this
  * by overriding `streamingTimeout`.
  */
-trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
+trait StreamTest extends QueryTest with SharedSQLContext with Timeouts with BeforeAndAfterAll {
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    StateStore.stop() // stop the state store maintenance thread and unload store providers
+  }
 
   /** How long to wait for an active stream to catch up when checking a result. */
   val streamingTimeout = 10.seconds

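The StreamTest hunk above wires suite-level cleanup through ScalaTest's BeforeAndAfterAll so the shared state-store machinery is shut down once after all tests in a suite have run. A minimal, hypothetical suite following the same pattern is sketched below, assuming a ScalaTest 3.0-era FunSuite; MyStreamingSuite and releaseSharedResources are illustrative names, not part of this patch.

import org.scalatest.{BeforeAndAfterAll, FunSuite}

// Hypothetical suite showing the cleanup pattern added to StreamTest above:
// let the parent finish its teardown, then release suite-wide resources
// (the real trait calls StateStore.stop() at this point).
class MyStreamingSuite extends FunSuite with BeforeAndAfterAll {

  override def afterAll(): Unit = {
    super.afterAll()
    releaseSharedResources()  // stand-in for StateStore.stop()
  }

  private def releaseSharedResources(): Unit = ()  // illustrative no-op

  test("example") {
    assert(1 + 1 == 2)
  }
}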

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
