Repository: spark Updated Branches: refs/heads/master b8a743b6a -> fe24634d1
[SPARK-21145][SS] Added StateStoreProviderId with queryRunId to reload StateStoreProviders when query is restarted ## What changes were proposed in this pull request? StateStoreProvider instances are loaded on-demand in an executor when a query is started. When a query is restarted, the loaded provider instance will get reused. Now, there is a non-trivial chance that the task of the previous query run is still running while the tasks of the restarted run have started. So for a stateful partition, there may be two concurrent tasks related to the same stateful partition, and therefore using the same provider instance. This can lead to inconsistent results and possibly random failures, as state store implementations are not designed to be thread-safe. To fix this, I have introduced a `StateStoreProviderId` that uniquely identifies a provider loaded in an executor. It has the query run id in it, thus making sure that restarted queries will force the executor to load a new provider instance, thus preventing two concurrent tasks (from two different runs) from reusing the same provider instance. Additional minor bug fixes - All state stores related to a query run are marked as deactivated in the `StateStoreCoordinator` so that the executors can unload them and clear resources. - Moved the code that determined the checkpoint directory of a state store from implementation-specific code (`HDFSBackedStateStoreProvider`) to non-specific code (StateStoreId), so that implementations do not accidentally get it wrong. - Also added store name to the path, to support multiple stores per sql operator partition. *Note:* This change does not address the scenario where two tasks of the same run (e.g. speculative tasks) are concurrently running in the same executor. The chance of this is very small, because ideally speculative tasks should never run in the same executor. ## How was this patch tested? Existing unit tests + new unit test. 
Author: Tathagata Das <tathagata.das1...@gmail.com> Closes #18355 from tdas/SPARK-21145. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fe24634d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fe24634d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fe24634d Branch: refs/heads/master Commit: fe24634d14bc0973ca38222db2f58eafbf0c890d Parents: b8a743b Author: Tathagata Das <tathagata.das1...@gmail.com> Authored: Fri Jun 23 00:43:21 2017 -0700 Committer: Tathagata Das <tathagata.das1...@gmail.com> Committed: Fri Jun 23 00:43:21 2017 -0700 ---------------------------------------------------------------------- .../sql/execution/aggregate/AggUtils.scala | 2 +- .../spark/sql/execution/command/commands.scala | 5 +- .../streaming/FlatMapGroupsWithStateExec.scala | 7 +- .../streaming/IncrementalExecution.scala | 27 +++--- .../execution/streaming/StreamExecution.scala | 1 + .../state/HDFSBackedStateStoreProvider.scala | 16 ++-- .../execution/streaming/state/StateStore.scala | 91 ++++++++++++++----- .../streaming/state/StateStoreCoordinator.scala | 41 +++++---- .../streaming/state/StateStoreRDD.scala | 21 +++-- .../sql/execution/streaming/state/package.scala | 25 ++---- .../execution/streaming/statefulOperators.scala | 38 ++++---- .../sql/streaming/StreamingQueryManager.scala | 1 + .../state/StateStoreCoordinatorSuite.scala | 61 +++++++++++-- .../streaming/state/StateStoreRDDSuite.scala | 51 ++++++----- .../streaming/state/StateStoreSuite.scala | 93 +++++++++++++++----- .../spark/sql/streaming/StreamSuite.scala | 2 +- .../apache/spark/sql/streaming/StreamTest.scala | 13 ++- 17 files changed, 329 insertions(+), 166 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala 
---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index aa789af..12f8cff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -311,7 +311,7 @@ object AggUtils { val saved = StateStoreSaveExec( groupingAttributes, - stateId = None, + stateInfo = None, outputMode = None, eventTimeWatermark = None, partialMerged2) http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 2d82fcf..81bc93e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.UUID + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -117,7 +119,8 @@ case class ExplainCommand( // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. 
new IncrementalExecution( - sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", 0, OffsetSeqMetadata(0, 0)) + sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", + UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) } else { sparkSession.sessionState.executePlan(logicalPlan) } http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 2aad870..9dcac33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -50,7 +50,7 @@ case class FlatMapGroupsWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, timeoutConf: GroupStateTimeout, @@ -107,10 +107,7 @@ case class FlatMapGroupsWithStateExec( } child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, groupingAttributes.toStructType, stateAttributes.toStructType, indexOrdinal = None, http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 622e049..ab89dc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging @@ -36,6 +37,7 @@ class IncrementalExecution( logicalPlan: LogicalPlan, val outputMode: OutputMode, val checkpointLocation: String, + val runId: UUID, val currentBatchId: Long, offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { @@ -69,7 +71,13 @@ class IncrementalExecution( * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ - private val operatorId = new AtomicInteger(0) + private val statefulOperatorId = new AtomicInteger(0) + + /** Get the state info of the next stateful operator */ + private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo( + checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId) + } /** Locates save/restore pairs surrounding aggregation. 
*/ val state = new Rule[SparkPlan] { @@ -78,35 +86,28 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - + val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, - Some(stateId), + Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), agg.withNewChildren( StateStoreRestoreExec( keys, - Some(stateId), + Some(aggStateInfo), child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - StreamingDeduplicateExec( keys, child, - Some(stateId), + Some(nextStatefulOperationStateInfo), Some(offsetSeqMetadata.batchWatermarkMs)) case m: FlatMapGroupsWithStateExec => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) m.copy( - stateId = Some(stateId), + stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) } http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 74f0f50..06bdec8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -652,6 +652,7 @@ class StreamExecution( triggerLogicalPlan, outputMode, checkpointFile("state"), + runId, currentBatchId, offsetSeqMetadata) 
lastExecution.executedPlan // Force the lazy generation of execution plan http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 67d86da..bae7a15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -92,7 +92,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null - override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id + override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId override def get(key: UnsafeRow): UnsafeRow = { mapToUpdate.get(key) @@ -177,7 +177,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** * Whether all updates have been committed */ - override private[streaming] def hasCommitted: Boolean = { + override def hasCommitted: Boolean = { state == COMMITTED } @@ -205,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): Unit = { - this.stateStoreId = stateStoreId + this.stateStoreId_ = stateStoreId this.keySchema = keySchema this.valueSchema = valueSchema this.storeConf = storeConf @@ -213,7 +213,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit fs.mkdirs(baseDir) } - override def id: 
StateStoreId = stateStoreId + override def stateStoreId: StateStoreId = stateStoreId_ /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { @@ -231,20 +231,20 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def toString(): String = { - s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStoreProvider[" + + s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]" } /* Internal fields and methods */ - @volatile private var stateStoreId: StateStoreId = _ + @volatile private var stateStoreId_ : StateStoreId = _ @volatile private var keySchema: StructType = _ @volatile private var valueSchema: StructType = _ @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ private lazy val loadedMaps = new mutable.HashMap[Long, MapType] - private lazy val baseDir = - new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private lazy val baseDir = stateStoreId.storeCheckpointLocation() private lazy val fs = baseDir.getFileSystem(hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 29c456f..a94ff8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,6 +17,7 @@ package 
org.apache.spark.sql.execution.streaming.state +import java.util.UUID import java.util.concurrent.{ScheduledFuture, TimeUnit} import javax.annotation.concurrent.GuardedBy @@ -24,14 +25,14 @@ import scala.collection.mutable import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ThreadUtils, Utils} - /** * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific * version of state data, and such instances are created through a [[StateStoreProvider]]. @@ -99,7 +100,7 @@ trait StateStore { /** * Whether all updates have been committed */ - private[streaming] def hasCommitted: Boolean + def hasCommitted: Boolean } @@ -147,7 +148,7 @@ trait StateStoreProvider { * Return the id of the StateStores this provider will generate. * Should be the same as the one passed in init(). */ - def id: StateStoreId + def stateStoreId: StateStoreId /** Called when the provider instance is unloaded from the executor */ def close(): Unit @@ -179,13 +180,46 @@ object StateStoreProvider { } } +/** + * Unique identifier for a provider, used to identify when providers can be reused. + * Note that `queryRunId` is used uniquely identify a provider, so that the same provider + * instance is not reused across query restarts. + */ +case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID) -/** Unique identifier for a bunch of keyed state data. */ +/** + * Unique identifier for a bunch of keyed state data. 
+ * @param checkpointRootLocation Root directory where all the state data of a query is stored + * @param operatorId Unique id of a stateful operator + * @param partitionId Index of the partition of an operators state data + * @param storeName Optional, name of the store. Each partition can optionally use multiple state + * stores, but they have to be identified by distinct names. + */ case class StateStoreId( - checkpointLocation: String, + checkpointRootLocation: String, operatorId: Long, partitionId: Int, - name: String = "") + storeName: String = StateStoreId.DEFAULT_STORE_NAME) { + + /** + * Checkpoint directory to be used by a single state store, identified uniquely by the tuple + * (operatorId, partitionId, storeName). All implementations of [[StateStoreProvider]] should + * use this path for saving state data, as this ensures that distinct stores will write to + * different locations. + */ + def storeCheckpointLocation(): Path = { + if (storeName == StateStoreId.DEFAULT_STORE_NAME) { + // For reading state store data that was generated before store names were used (Spark <= 2.2) + new Path(checkpointRootLocation, s"$operatorId/$partitionId") + } else { + new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName") + } + } +} + +object StateStoreId { + val DEFAULT_STORE_NAME = "default" +} /** Mutable, and reusable class for representing a pair of UnsafeRows. */ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { @@ -211,7 +245,7 @@ object StateStore extends Logging { val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 @GuardedBy("loadedProviders") - private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() + private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]() /** * Runs the `task` periodically and automatically cancels it if there is an exception. 
`onError` @@ -253,7 +287,7 @@ object StateStore extends Logging { /** Get or create a store associated with the id. */ def get( - storeId: StateStoreId, + storeProviderId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -264,24 +298,24 @@ object StateStore extends Logging { val storeProvider = loadedProviders.synchronized { startMaintenanceIfNeeded() val provider = loadedProviders.getOrElseUpdate( - storeId, + storeProviderId, StateStoreProvider.instantiate( - storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) ) - reportActiveStoreInstance(storeId) + reportActiveStoreInstance(storeProviderId) provider } storeProvider.getStore(version) } /** Unload a state store provider */ - def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeId).foreach(_.close()) + def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeProviderId).foreach(_.close()) } /** Whether a state store provider is loaded or not */ - def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { - loadedProviders.contains(storeId) + def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeProviderId) } def isMaintenanceRunning: Boolean = loadedProviders.synchronized { @@ -340,21 +374,21 @@ object StateStore extends Logging { } } - private def reportActiveStoreInstance(storeId: StateStoreId): Unit = { + private def reportActiveStoreInstance(storeProviderId: StateStoreProviderId): Unit = { if (SparkEnv.get != null) { val host = SparkEnv.get.blockManager.blockManagerId.host val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId)) - logDebug(s"Reported that the loaded 
instance $storeId is active") + coordinatorRef.foreach(_.reportActiveInstance(storeProviderId, host, executorId)) + logInfo(s"Reported that the loaded instance $storeProviderId is active") } } - private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = { + private def verifyIfStoreInstanceActive(storeProviderId: StateStoreProviderId): Boolean = { if (SparkEnv.get != null) { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId val verified = - coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) - logDebug(s"Verified whether the loaded instance $storeId is active: $verified") + coordinatorRef.map(_.verifyIfInstanceActive(storeProviderId, executorId)).getOrElse(false) + logDebug(s"Verified whether the loaded instance $storeProviderId is active: $verified") verified } else { false @@ -364,12 +398,21 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - if (_coordRef == null) { + logInfo("Env is not null") + val isDriver = + env.executorId == SparkContext.DRIVER_IDENTIFIER || + env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + // If running locally, then the coordinator reference in _coordRef may be have become inactive + // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, + // always recreate the reference. 
+ if (isDriver || _coordRef == null) { + logInfo("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } - logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") + logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { + logInfo("Env is null") _coordRef = null None } http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index d0f8188..3884f5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.collection.mutable import org.apache.spark.SparkEnv @@ -29,16 +31,19 @@ import org.apache.spark.util.RpcUtils private sealed trait StateStoreCoordinatorMessage extends Serializable /** Classes representing messages */ -private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) +private case class ReportActiveInstance( + storeId: StateStoreProviderId, + host: String, + executorId: String) extends StateStoreCoordinatorMessage -private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) +private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, executorId: String) extends StateStoreCoordinatorMessage -private case class GetLocation(storeId: StateStoreId) +private case class GetLocation(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage 
-private case class DeactivateInstances(checkpointLocation: String) +private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -80,25 +85,27 @@ object StateStoreCoordinatorRef extends Logging { class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[state] def reportActiveInstance( - storeId: StateStoreId, + stateStoreProviderId: StateStoreProviderId, host: String, executorId: String): Unit = { - rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId)) + rpcEndpointRef.send(ReportActiveInstance(stateStoreProviderId, host, executorId)) } /** Verify whether the given executor has the active instance of a state store */ - private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { - rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(storeId, executorId)) + private[state] def verifyIfInstanceActive( + stateStoreProviderId: StateStoreProviderId, + executorId: String): Boolean = { + rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(stateStoreProviderId, executorId)) } /** Get the location of the state store */ - private[state] def getLocation(storeId: StateStoreId): Option[String] = { - rpcEndpointRef.askSync[Option[String]](GetLocation(storeId)) + private[state] def getLocation(stateStoreProviderId: StateStoreProviderId): Option[String] = { + rpcEndpointRef.askSync[Option[String]](GetLocation(stateStoreProviderId)) } - /** Deactivate instances related to a set of operator */ - private[state] def deactivateInstances(storeRootLocation: String): Unit = { - rpcEndpointRef.askSync[Boolean](DeactivateInstances(storeRootLocation)) + /** Deactivate instances related to a query */ + private[sql] def deactivateInstances(runId: UUID): Unit = { + rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } private[state] def stop(): Unit = { @@ -113,7 +120,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { 
*/ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => @@ -135,11 +142,11 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) logDebug(s"Got location of the state store $id: $executorId") context.reply(executorId) - case DeactivateInstances(checkpointLocation) => + case DeactivateInstances(runId) => val storeIdsToRemove = - instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq + instances.keys.filter(_.queryRunId == runId).toSeq instances --= storeIdsToRemove - logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " + + logDebug(s"Deactivating instances related to checkpoint location $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index b744c25..01d8e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} @@ -34,8 +36,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], storeUpdateFunction: 
(StateStore, Iterator[T]) => Iterator[U], checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, @@ -52,16 +54,25 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override protected def getPartitions: Array[Partition] = dataRDD.partitions + /** + * Set the preferred location of each partition using the executor that has the related + * [[StateStoreProvider]] already loaded. + */ override def getPreferredLocations(partition: Partition): Seq[String] = { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) - storeCoordinator.flatMap(_.getLocation(storeId)).toSeq + val stateStoreProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq } override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) + val storeProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + store = StateStore.get( - storeId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 228fe86..a0086e2 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.TaskContext @@ -32,20 +34,14 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. */ def mapPartitionsWithStateStore[U: ClassTag]( sqlContext: SQLContext, - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { mapPartitionsWithStateStore( - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo, keySchema, valueSchema, indexOrdinal, @@ -56,10 +52,7 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. 
*/ private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -79,10 +72,10 @@ package object state { new StateStoreRDD( dataRDD, wrappedF, - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo.checkpointLocation, + stateInfo.queryRunId, + stateInfo.operatorId, + stateInfo.storeVersion, keySchema, valueSchema, indexOrdinal, http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3e57f3f..c572246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.TimeUnit._ import org.apache.spark.rdd.RDD @@ -36,20 +37,22 @@ import org.apache.spark.util.{CompletionIterator, NextIterator} /** Used to identify the state store for a given operator. */ -case class OperatorStateId( +case class StatefulOperatorStateInfo( checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - batchId: Long) + storeVersion: Long) /** - * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should - * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + * An operator that reads or writes state from the [[StateStore]]. 
+ * The [[StatefulOperatorStateInfo]] should be filled in by `prepareForExecution` in + * [[IncrementalExecution]]. */ trait StatefulOperator extends SparkPlan { - def stateId: Option[OperatorStateId] + def stateInfo: Option[StatefulOperatorStateInfo] - protected def getStateId: OperatorStateId = attachTree(this) { - stateId.getOrElse { + protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) { + stateInfo.getOrElse { throw new IllegalStateException("State location not present for execution") } } @@ -140,7 +143,7 @@ trait WatermarkSupport extends UnaryExecNode { */ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], child: SparkPlan) extends UnaryExecNode with StateStoreReader { @@ -148,10 +151,7 @@ case class StateStoreRestoreExec( val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeName = "default", - storeVersion = getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -177,7 +177,7 @@ case class StateStoreRestoreExec( */ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) @@ -189,10 +189,7 @@ case class StateStoreSaveExec( "Incorrect planning in IncrementalExecution, outputMode has not been set") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -319,7 +316,7 @@ case class StateStoreSaveExec( case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], child: 
SparkPlan, - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, eventTimeWatermark: Option[Long] = None) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -331,10 +328,7 @@ case class StreamingDeduplicateExec( metrics // force lazy init at driver child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 002c454..48b0ea2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -332,5 +332,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } awaitTerminationLock.notifyAll() } + stateStoreCoordinator.deactivateInstances(terminatedQuery.runId) } } http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index a7e3262..9a7595e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,11 +17,17 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count +import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -29,7 +35,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("report, verify, getLocation") { withCoordinatorRef(sc) { coordinatorRef => - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) assert(coordinatorRef.getLocation(id) === None) @@ -57,9 +63,11 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("make inactive") { withCoordinatorRef(sc) { coordinatorRef => - val id1 = StateStoreId("x", 0, 0) - val id2 = StateStoreId("y", 1, 0) - val id3 = StateStoreId("x", 0, 1) + val runId1 = UUID.randomUUID + val runId2 = UUID.randomUUID + val id1 = StateStoreProviderId(StateStoreId("x", 0, 0), runId1) + val id2 = StateStoreProviderId(StateStoreId("y", 1, 0), runId2) + val id3 = StateStoreProviderId(StateStoreId("x", 0, 1), runId1) val host = "hostX" val exec = "exec1" @@ -73,7 +81,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) } - coordinatorRef.deactivateInstances("x") + coordinatorRef.deactivateInstances(runId1) assert(coordinatorRef.verifyIfInstanceActive(id1, 
exec) === false) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) @@ -85,7 +93,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { Some(ExecutorCacheTaskLocation(host, exec).toString)) assert(coordinatorRef.getLocation(id3) === None) - coordinatorRef.deactivateInstances("y") + coordinatorRef.deactivateInstances(runId2) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) assert(coordinatorRef.getLocation(id2) === None) } @@ -95,7 +103,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { withCoordinatorRef(sc) { coordRef1 => val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) coordRef1.reportActiveInstance(id, "hostX", "exec1") @@ -107,6 +115,45 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } } + + test("query stop deactivates related store providers") { + var coordRef: StateStoreCoordinatorRef = null + try { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + SparkSession.setActiveSession(spark) + import spark.implicits._ + coordRef = spark.streams.stateStoreCoordinator + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + + // Start a query and run a batch to load state stores + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + + // Verify state store has been loaded + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + 
val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, 0), query.runId) + assert(coordRef.getLocation(providerId).nonEmpty) + + // Stop and verify whether the stores are deactivated in the coordinator + query.stop() + assert(coordRef.getLocation(providerId).isEmpty) + } finally { + SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop())) + if (coordRef != null) coordRef.stop() + StateStore.stop() + } + } } object StateStoreCoordinatorSuite { http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 4a1a089..defb9ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -19,20 +19,19 @@ package org.apache.spark.sql.execution.streaming.state import java.io.File import java.nio.file.Files +import java.util.UUID import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.LocalSparkSession._ -import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.{IntegerType, StringType, StructField, 
StructType} import org.apache.spark.util.{CompletionIterator, Utils} @@ -57,16 +56,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val opId = 0 - val rdd1 = - makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) @@ -76,7 +73,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( @@ -85,7 +81,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn storeVersion: Int): RDD[(String, Int)] = { implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = storeVersion), + keySchema, valueSchema, None)(increment) } // Generate RDDs and state store data @@ -132,17 +129,17 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter 
with BeforeAn } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } @@ -150,22 +147,25 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("preferred locations using StateStoreCoordinator") { quietly { + val queryRunId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2") + val storeProviderId1 = StateStoreProviderId(StateStoreId(path, opId, 0), queryRunId) + val storeProviderId2 = StateStoreProviderId(StateStoreId(path, opId, 1), queryRunId) + coordinatorRef.reportActiveInstance(storeProviderId1, 
"host1", "exec1") + coordinatorRef.reportActiveInstance(storeProviderId2, "host2", "exec2") - assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) === + require( + coordinatorRef.getLocation(storeProviderId1) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( - increment) + sqlContext, operatorStateInfo(path, queryRunId = queryRunId), + keySchema, valueSchema, None)(increment) require(rdd.partitions.length === 2) assert( @@ -192,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
@@ -210,6 +210,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) } + private def operatorStateInfo( + path: String, + queryRunId: UUID = UUID.randomUUID, + version: Int = 0): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + } + private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => val key = stringToRow(s) http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index af2b9f1..c2087ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import java.util.UUID import scala.collection.JavaConverters._ import scala.collection.mutable @@ -33,8 +34,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -143,7 
+147,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] provider.getStore(0).commit() // Verify we don't leak temp files - val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation), + val tempFiles = FileUtils.listFiles(new File(provider.stateStoreId.checkpointRootLocation), null, true).asScala.filter(_.getName.startsWith("temp-")) assert(tempFiles.isEmpty) } @@ -183,7 +187,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("StateStore.get") { quietly { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() @@ -243,18 +247,18 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] .set("spark.rpc.numRetries", "1") val opId = 0 val dir = newDir() - val storeId = StateStoreId(dir, opId, 0) + val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), UUID.randomUUID) val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val storeConf = StateStoreConf(sqlConf) val hadoopConf = new Configuration() - val provider = newStoreProvider(storeId) + val provider = newStoreProvider(storeProviderId.storeId) var latestStoreVersion = 0 def generateStoreVersions() { for (i <- 1 to 20) { - val store = StateStore.get(storeId, keySchema, valueSchema, None, + val store = StateStore.get(storeProviderId, keySchema, valueSchema, None, latestStoreVersion, storeConf, hadoopConf) put(store, "a", i) store.commit() @@ -274,7 +278,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] eventually(timeout(timeoutDuration)) { // Store should have been reported to the coordinator - assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + assert(coordinatorRef.getLocation(storeProviderId).nonEmpty, + "active instance was not 
reported") // Background maintenance should clean up and generate snapshots assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") @@ -295,35 +300,35 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") } - // If driver decides to deactivate all instances of the store, then this instance - // should be unloaded - coordinatorRef.deactivateInstances(dir) + // If driver decides to deactivate all stores related to a query run, + // then this instance should be unloaded + coordinatorRef.deactivateInstances(storeProviderId.queryRunId) eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) // If some other executor loads the store, then this instance should be unloaded - coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec") eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) } } // Verify if instance is unloaded if SparkContext is stopped eventually(timeout(timeoutDuration)) { require(SparkEnv.get === null) - assert(!StateStore.isLoaded(storeId)) + 
assert(!StateStore.isLoaded(storeProviderId)) assert(!StateStore.isMaintenanceRunning) } } @@ -344,7 +349,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("SPARK-18416: do not create temp delta file until the store is updated") { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() val deltaFileDir = new File(s"$dir/0/0/") @@ -408,12 +413,60 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(numDeltaFiles === 3) } + test("SPARK-21145: Restarted queries create new provider instances") { + try { + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val spark = SparkSession.builder().master("local[2]").getOrCreate() + SparkSession.setActiveSession(spark) + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + import spark.implicits._ + val inputData = MemoryStream[Int] + + def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = { + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) + // stateful query + val query = aggregated.writeStream + .format("memory") + .outputMode("complete") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + require(query.lastProgress != null) // at least one batch processed after start + val loadedProvidersMethod = + PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]('loadedProviders) + val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod() + val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq } + query.stop() + loadedProviders + } + + val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders() + require(loadedProvidersAfterRun1.length === 1) + + 
val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders() + assert(loadedProvidersAfterRun2.length === 2) // two providers loaded for 2 runs + + // Both providers should have the same StateStoreId, but they should be different objects + assert(loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId) + assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1)) + + } finally { + SparkSession.getActiveSession.foreach { spark => + spark.streams.active.foreach(_.stop()) + spark.stop() + } + } + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = { - newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation) + newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointRootLocation) } override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = { @@ -423,7 +476,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] override def getData( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { - val reloadedProvider = newStoreProvider(provider.id) + val reloadedProvider = newStoreProvider(provider.stateStoreId) if (version < 0) { reloadedProvider.latestIterator().map(rowsToStringInt).toSet } else { http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 4ede4fd..86c3a35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -777,7 +777,7 @@ class TestStateStoreProvider extends StateStoreProvider { throw new Exception("Successfully instantiated") } - override def id: StateStoreId = null + override def stateStoreId: StateStoreId = null override def close(): Unit = { } http://git-wip-us.apache.org/repos/asf/spark/blob/fe24634d/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 2a4039c..b2c42ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -26,9 +26,8 @@ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import org.scalatest.Assertions +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, Timeouts} -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span @@ -39,9 +38,10 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * A framework for implementing tests for streaming queries and sources. 
@@ -67,7 +67,12 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} * avoid hanging forever in the case of failures. However, individual suites can change this * by overriding `streamingTimeout`. */ -trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { +trait StreamTest extends QueryTest with SharedSQLContext with Timeouts with BeforeAndAfterAll { + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() // stop the state store maintenance thread and unload store providers + } /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org