[SPARK-20883][SPARK-20376][SS] Refactored StateStore APIs and added conf to choose implementation
## What changes were proposed in this pull request? A bunch of changes to the StateStore APIs and implementation. The current state store API has a number of problems that cause too many transient objects to be created, leading to memory pressure. - `StateStore.get(): Option` forces the creation of a Some/None object for every get. Changed this to return the row or null. - `StateStore.iterator(): (UnsafeRow, UnsafeRow)` forces the creation of a new tuple for each record returned. Changed this to return an `UnsafeRowPair`, which can be reused across records. - `StateStore.updates()` requires the implementation to keep track of updates, while this is used minimally (only by Append mode in streaming aggregations). Removed updates() and updated StateStoreSaveExec accordingly. - `StateStore.filter(condition)` and `StateStore.remove(condition)` have been merged into a single API, `getRange(start, end)`, which allows a state store to do optimized range queries (i.e. avoid full scans). Stateful operators have been updated accordingly. - Removed a lot of unnecessary row copies. Each operator copied rows before calling StateStore.put() even if the implementation did not require them to be copied. It is now left up to the implementation whether to copy the row or not. Additionally, - Added a name to the StateStoreId so that each operator+partition can use multiple state stores (with different names). - Added a configuration that allows the user to specify which implementation to use. - Added new metrics to understand the time taken to update keys, remove keys and commit all changes to the state store. These metrics will be visible on the plan diagram in the SQL tab of the UI. - Refactored unit tests such that they can be reused to test any implementation of StateStore. ## How was this patch tested? Old and new unit tests Author: Tathagata Das <tathagata.das1...@gmail.com> Closes #18107 from tdas/SPARK-20376. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fa757ee1 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fa757ee1 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fa757ee1 Branch: refs/heads/master Commit: fa757ee1d41396ad8734a3f2dd045bb09bc82a2e Parents: 4bb6a53 Author: Tathagata Das <tathagata.das1...@gmail.com> Authored: Tue May 30 15:33:06 2017 -0700 Committer: Shixiong Zhu <shixi...@databricks.com> Committed: Tue May 30 15:33:06 2017 -0700 ---------------------------------------------------------------------- .../org/apache/spark/sql/internal/SQLConf.scala | 11 + .../streaming/FlatMapGroupsWithStateExec.scala | 39 +- .../state/HDFSBackedStateStoreProvider.scala | 218 +++----- .../execution/streaming/state/StateStore.scala | 163 ++++-- .../streaming/state/StateStoreConf.scala | 28 +- .../streaming/state/StateStoreRDD.scala | 11 +- .../sql/execution/streaming/state/package.scala | 11 +- .../execution/streaming/statefulOperators.scala | 142 +++-- .../streaming/state/StateStoreRDDSuite.scala | 41 +- .../streaming/state/StateStoreSuite.scala | 534 +++++++++---------- .../streaming/FlatMapGroupsWithStateSuite.scala | 40 +- .../spark/sql/streaming/StreamSuite.scala | 45 ++ 12 files changed, 695 insertions(+), 588 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c5d69c2..c6f5cf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -552,6 +552,15 @@ object SQLConf { .booleanConf 
.createWithDefault(true) + val STATE_STORE_PROVIDER_CLASS = + buildConf("spark.sql.streaming.stateStore.providerClass") + .internal() + .doc( + "The class used to manage state data in stateful streaming queries. This class must " + + "be a subclass of StateStoreProvider, and must have a zero-arg constructor.") + .stringConf + .createOptional + val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot") .internal() @@ -828,6 +837,8 @@ class SQLConf extends Serializable with Logging { def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) + def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS) + def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 3ceb4cf..2aad870 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -109,9 +109,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, getStateId.operatorId, + storeName = "default", getStateId.batchId, groupingAttributes.toStructType, stateAttributes.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val updater = new StateStoreUpdater(store) @@ -191,12 
+193,12 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.filter { case (_, stateRow) => - val timeoutTimestamp = getTimeoutTimestamp(stateRow) + val timingOutKeys = store.getRange(None, None).filter { rowPair => + val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { case (keyRow, stateRow) => - callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true) + timingOutKeys.flatMap { rowPair => + callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) } } else Iterator.empty } @@ -205,18 +207,23 @@ case class FlatMapGroupsWithStateExec( * Call the user function on a key's data, update the state store, and return the return data * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. 
+ * + * @param keyRow Row representing the key, cannot be null + * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty + * @param prevStateRow Row representing the previous state, can be null + * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow], - prevStateRowOption: Option[UnsafeRow], + prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { val keyObj = getKeyObj(keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObjOption = getStateObj(prevStateRowOption) + val stateObj = getStateObj(prevStateRow) val keyedState = GroupStateImpl.createForStreaming( - stateObjOption, + Option(stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, @@ -249,14 +256,11 @@ case class FlatMapGroupsWithStateExec( numUpdatedStateRows += 1 } else { - val previousTimeoutTimestamp = prevStateRowOption match { - case Some(row) => getTimeoutTimestamp(row) - case None => NO_TIMESTAMP - } + val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) val stateRowToWrite = if (keyedState.hasUpdated) { getStateRow(keyedState.get) } else { - prevStateRowOption.orNull + prevStateRow } val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp @@ -269,7 +273,7 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException("Attempting to write empty state") } setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow.copy(), stateRowToWrite.copy()) + store.put(keyRow, stateRowToWrite) numUpdatedStateRows += 1 } } @@ -280,18 +284,21 @@ case class FlatMapGroupsWithStateExec( } /** Returns the state as Java object if defined */ - def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = { - stateRowOption.map(getStateObjFromRow) + def 
getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow != null) getStateObjFromRow(stateRow) else null } /** Returns the row for an updated state */ def getStateRow(obj: Any): UnsafeRow = { + assert(obj != null) getStateRowFromObj(obj) } /** Returns the timeout timestamp of a state row is set */ def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP + if (isTimeoutEnabled && stateRow != null) { + stateRow.getLong(timeoutTimestampIndex) + } else NO_TIMESTAMP } /** Set the timestamp in a state row */ http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index fb2bf47..67d86da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -67,13 +67,7 @@ import org.apache.spark.util.Utils * to ensure re-executed RDD operations re-apply updates on the correct past version of the * store. 
*/ -private[state] class HDFSBackedStateStoreProvider( - val id: StateStoreId, - keySchema: StructType, - valueSchema: StructType, - storeConf: StateStoreConf, - hadoopConf: Configuration - ) extends StateStoreProvider with Logging { +private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging { // ConcurrentHashMap is used because it generates fail-safe iterators on filtering // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in @@ -95,92 +89,36 @@ private[state] class HDFSBackedStateStoreProvider( private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) - private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() - @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id - override def get(key: UnsafeRow): Option[UnsafeRow] = { - Option(mapToUpdate.get(key)) - } - - override def filter( - condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { - mapToUpdate - .entrySet - .asScala - .iterator - .filter { entry => condition(entry.getKey, entry.getValue) } - .map { entry => (entry.getKey, entry.getValue) } + override def get(key: UnsafeRow): UnsafeRow = { + mapToUpdate.get(key) } override def put(key: UnsafeRow, value: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot put after already committed or aborted") - - val isNewKey = !mapToUpdate.containsKey(key) - mapToUpdate.put(key, value) - - Option(allUpdates.get(key)) match { - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added already, keep it marked as added - allUpdates.put(key, ValueAdded(key, value)) - case Some(ValueUpdated(_, _)) | Some(ValueRemoved(_, _)) => - // Value existed in previous version and 
updated/removed, mark it as updated - allUpdates.put(key, ValueUpdated(key, value)) - case None => - // There was no prior update, so mark this as added or updated according to its presence - // in previous version. - val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value) - allUpdates.put(key, update) - } - writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) + val keyCopy = key.copy() + val valueCopy = value.copy() + mapToUpdate.put(keyCopy, valueCopy) + writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy) } - /** Remove keys that match the following condition */ - override def remove(condition: UnsafeRow => Boolean): Unit = { + override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") - val entryIter = mapToUpdate.entrySet().iterator() - while (entryIter.hasNext) { - val entry = entryIter.next - if (condition(entry.getKey)) { - val value = entry.getValue - val key = entry.getKey - entryIter.remove() - - Option(allUpdates.get(key)) match { - case Some(ValueUpdated(_, _)) | None => - // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, ValueRemoved(key, value)) - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added, should not appear in updates - allUpdates.remove(key) - case Some(ValueRemoved(_, _)) => - // Remove already in update map, no need to change - } - writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) - } + val prevValue = mapToUpdate.remove(key) + if (prevValue != null) { + writeRemoveToDeltaFile(tempDeltaFileStream, key) } } - /** Remove a single key. 
*/ - override def remove(key: UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or aborted") - if (mapToUpdate.containsKey(key)) { - val value = mapToUpdate.remove(key) - Option(allUpdates.get(key)) match { - case Some(ValueUpdated(_, _)) | None => - // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, ValueRemoved(key, value)) - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added, should not appear in updates - allUpdates.remove(key) - case Some(ValueRemoved(_, _)) => - // Remove already in update map, no need to change - } - writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) - } + override def getRange( + start: Option[UnsafeRow], + end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = { + verify(state == UPDATING, "Cannot getRange after already committed or aborted") + iterator() } /** Commit all the updates that have been made to the store, and return the new version. */ @@ -227,20 +165,11 @@ private[state] class HDFSBackedStateStoreProvider( * Get an iterator of all the store data. * This can be called only after committing all the updates made in the current thread. */ - override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - verify(state == COMMITTED, - "Cannot get iterator of store data before committing or after aborting") - HDFSBackedStateStoreProvider.this.iterator(newVersion) - } - - /** - * Get an iterator of all the updates made to the store in the current version. - * This can be called only after committing all the updates made in the current thread. 
- */ - override def updates(): Iterator[StoreUpdate] = { - verify(state == COMMITTED, - "Cannot get iterator of updates before committing or after aborting") - allUpdates.values().asScala.toIterator + override def iterator(): Iterator[UnsafeRowPair] = { + val unsafeRowPair = new UnsafeRowPair() + mapToUpdate.entrySet.asScala.iterator.map { entry => + unsafeRowPair.withRows(entry.getKey, entry.getValue) + } } override def numKeys(): Long = mapToUpdate.size() @@ -269,6 +198,23 @@ private[state] class HDFSBackedStateStoreProvider( store } + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], // for sorting the data + storeConf: StateStoreConf, + hadoopConf: Configuration): Unit = { + this.stateStoreId = stateStoreId + this.keySchema = keySchema + this.valueSchema = valueSchema + this.storeConf = storeConf + this.hadoopConf = hadoopConf + fs.mkdirs(baseDir) + } + + override def id: StateStoreId = stateStoreId + /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { try { @@ -280,19 +226,27 @@ private[state] class HDFSBackedStateStoreProvider( } } + override def close(): Unit = { + loadedMaps.values.foreach(_.clear()) + } + override def toString(): String = { s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" } - /* Internal classes and methods */ + /* Internal fields and methods */ - private val loadedMaps = new mutable.HashMap[Long, MapType] - private val baseDir = - new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") - private val fs = baseDir.getFileSystem(hadoopConf) - private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + @volatile private var stateStoreId: StateStoreId = _ + @volatile private var keySchema: StructType = _ + @volatile private var valueSchema: StructType = _ + @volatile private var storeConf: 
StateStoreConf = _ + @volatile private var hadoopConf: Configuration = _ - initialize() + private lazy val loadedMaps = new mutable.HashMap[Long, MapType] + private lazy val baseDir = + new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private lazy val fs = baseDir.getFileSystem(hadoopConf) + private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) @@ -323,35 +277,18 @@ private[state] class HDFSBackedStateStoreProvider( * Get iterator of all the data of the latest version of the store. * Note that this will look up the files to determined the latest known version. */ - private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized { val versionsInFiles = fetchFiles().map(_.version).toSet val versionsLoaded = loadedMaps.keySet val allKnownVersions = versionsInFiles ++ versionsLoaded + val unsafeRowTuple = new UnsafeRowPair() if (allKnownVersions.nonEmpty) { - loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x => - (x.getKey, x.getValue) + loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { entry => + unsafeRowTuple.withRows(entry.getKey, entry.getValue) } } else Iterator.empty } - /** Get iterator of a specific version of the store */ - private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { - loadMap(version).entrySet().iterator().asScala.map { x => - (x.getKey, x.getValue) - } - } - - /** Initialize the store provider */ - private def initialize(): Unit = { - try { - fs.mkdirs(baseDir) - } catch { - case e: IOException => - throw new IllegalStateException( - s"Cannot use ${id.checkpointLocation} for storing state data for $this: $e ", e) - } - } - /** Load the required version of the map data from the backing files */ private def loadMap(version: Long): 
MapType = { if (version <= 0) return new MapType @@ -367,32 +304,23 @@ private[state] class HDFSBackedStateStoreProvider( } } - private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = { - - def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = { - val keyBytes = key.getBytes() - val valueBytes = value.getBytes() - output.writeInt(keyBytes.size) - output.write(keyBytes) - output.writeInt(valueBytes.size) - output.write(valueBytes) - } - - def writeRemove(key: UnsafeRow): Unit = { - val keyBytes = key.getBytes() - output.writeInt(keyBytes.size) - output.write(keyBytes) - output.writeInt(-1) - } + private def writeUpdateToDeltaFile( + output: DataOutputStream, + key: UnsafeRow, + value: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + val valueBytes = value.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } - update match { - case ValueAdded(key, value) => - writeUpdate(key, value) - case ValueUpdated(key, value) => - writeUpdate(key, value) - case ValueRemoved(key, value) => - writeRemove(key) - } + private def writeRemoveToDeltaFile(output: DataOutputStream, key: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(-1) } private def finalizeDeltaFile(output: DataOutputStream): Unit = { http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index eaa558e..29c456f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -29,15 +29,12 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ThreadUtils - - -/** Unique identifier for a [[StateStore]] */ -case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int) +import org.apache.spark.util.{ThreadUtils, Utils} /** - * Base trait for a versioned key-value store used for streaming aggregations + * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific + * version of state data, and such instances are created through a [[StateStoreProvider]]. */ trait StateStore { @@ -47,50 +44,54 @@ trait StateStore { /** Version of the data in this store before committing updates. */ def version: Long - /** Get the current value of a key. */ - def get(key: UnsafeRow): Option[UnsafeRow] - /** - * Return an iterator of key-value pairs that satisfy a certain condition. - * Note that the iterator must be fail-safe towards modification to the store, that is, - * it must be based on the snapshot of store the time of this call, and any change made to the - * store while iterating through iterator should not cause the iterator to fail or have - * any affect on the values in the iterator. + * Get the current value of a non-null key. + * @return a non-null row if the key exists in the store, otherwise null. */ - def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] + def get(key: UnsafeRow): UnsafeRow - /** Put a new value for a key. */ + /** + * Put a new value for a non-null key. Implementations must be aware that the UnsafeRows in + * the params can be reused, and must make copies of the data as needed for persistence. 
+ */    def put(key: UnsafeRow, value: UnsafeRow): Unit     /** -   * Remove keys that match the following condition. +   * Remove a single non-null key.     */ -  def remove(condition: UnsafeRow => Boolean): Unit +  def remove(key: UnsafeRow): Unit     /** -   * Remove a single key. +   * Get key-value pairs with optional approximate `start` and `end` extents. +   * If the State Store implementation maintains indices for the data based on the optional +   * `keyIndexOrdinal` over fields `keySchema` (see `StateStoreProvider.init()`), then it can use +   * `start` and `end` to make a best-effort scan over the data. Default implementation returns +   * the full data scan iterator, which is correct but inefficient. Custom implementations must +   * ensure that updates (puts, removes) can be made while iterating over this iterator. +   * +   * @param start UnsafeRow having the `keyIndexOrdinal` column set with appropriate starting value. +   * @param end UnsafeRow having the `keyIndexOrdinal` column set with appropriate ending value. +   * @return An iterator of key-value pairs that is guaranteed not to miss any key between start and +   *         end, both inclusive.     */ -  def remove(key: UnsafeRow): Unit +  def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = { +    iterator() +  }     /**     * Commit all the updates that have been made to the store, and return the new version. +   * Implementations should ensure that no updates (puts, removes) can be made after a commit in +   * order to avoid incorrect usage.     */    def commit(): Long -  /** Abort all the updates that have been made to the store. */ -  def abort(): Unit -    /** -   * Iterator of store data after a set of updates have been committed. -   * This can be called only after committing all the updates made in the current thread. +   * Abort all the updates that have been made to the store. Implementations should ensure that +   * no updates (puts, removes) can be made after an abort in order to avoid incorrect usage.
*/ -  def iterator(): Iterator[(UnsafeRow, UnsafeRow)] +  def abort(): Unit -  /** -   * Iterator of the updates that have been committed. -   * This can be called only after committing all the updates made in the current thread. -   */ -  def updates(): Iterator[StoreUpdate] +  def iterator(): Iterator[UnsafeRowPair]    /** Number of keys in the state store */   def numKeys(): Long @@ -102,28 +103,98 @@ trait StateStore {  }   -/** Trait representing a provider of a specific version of a [[StateStore]]. */ +/** + * Trait representing a provider that provides [[StateStore]] instances representing + * versions of state data. + * + * The life cycle of a provider and its provided stores is as follows. + * + * - A StateStoreProvider is created in an executor for each unique [[StateStoreId]] when + * the first batch of a streaming query is executed on the executor. All subsequent batches reuse + * this provider instance until the query is stopped. + * + * - Every batch of streaming data requests a specific version of the state data by invoking + * `getStore(version)` which returns an instance of [[StateStore]] through which the required + * version of the data can be accessed. It is the responsibility of the provider to populate + * this store with context information like the schema of keys and values, etc. + * + * - After the streaming query is stopped, the created provider instances are lazily disposed of. + */ trait StateStoreProvider { -  /** Get the store with the existing version. */ +  /** +   * Initialize the provider with more contextual information from the SQL operator. +   * This method will be called first after creating an instance of the StateStoreProvider by +   * reflection.
+   * +   * @param stateStoreId Id of the versioned StateStores that this provider will generate +   * @param keySchema Schema of keys to be stored +   * @param valueSchema Schema of values to be stored +   * @param keyIndexOrdinal Optional column (given as the ordinal of a field in keySchema) by +   *                        which the StateStore implementation could index the data. +   * @param storeConfs Configurations used by the StateStores +   * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data +   */ +  def init( +      stateStoreId: StateStoreId, +      keySchema: StructType, +      valueSchema: StructType, +      keyIndexOrdinal: Option[Int], // for sorting the data by their keys +      storeConfs: StateStoreConf, +      hadoopConf: Configuration): Unit + +  /** +   * Return the id of the StateStores this provider will generate. +   * Should be the same as the one passed in init(). +   */ +  def id: StateStoreId + +  /** Called when the provider instance is unloaded from the executor */ +  def close(): Unit + +  /** Return an instance of [[StateStore]] representing state data of the given version */    def getStore(version: Long): StateStore -  /** Optional method for providers to allow for background maintenance */ +  /** Optional method for providers to allow for background maintenance (e.g. compactions) */    def doMaintenance(): Unit = { }  }   - -/** Trait representing updates made to a [[StateStore]]. */ -sealed trait StoreUpdate { -  def key: UnsafeRow -  def value: UnsafeRow +object StateStoreProvider { +  /** +   * Return a provider instance of the given provider class. +   * The instance will already be initialized.
+ */ + def instantiate( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], // for sorting the data + storeConf: StateStoreConf, + hadoopConf: Configuration): StateStoreProvider = { + val providerClass = storeConf.providerClass.map(Utils.classForName) + .getOrElse(classOf[HDFSBackedStateStoreProvider]) + val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider] + provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + provider + } } -case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate -case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +/** Unique identifier for a bunch of keyed state data. */ +case class StateStoreId( + checkpointLocation: String, + operatorId: Long, + partitionId: Int, + name: String = "") -case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +/** Mutable, and reusable class for representing a pair of UnsafeRows. 
*/ +class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { + def withRows(key: UnsafeRow, value: UnsafeRow): UnsafeRowPair = { + this.key = key + this.value = value + this + } +} /** @@ -185,6 +256,7 @@ object StateStore extends Logging { storeId: StateStoreId, keySchema: StructType, valueSchema: StructType, + indexOrdinal: Option[Int], version: Long, storeConf: StateStoreConf, hadoopConf: Configuration): StateStore = { @@ -193,7 +265,9 @@ object StateStore extends Logging { startMaintenanceIfNeeded() val provider = loadedProviders.getOrElseUpdate( storeId, - new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf)) + StateStoreProvider.instantiate( + storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + ) reportActiveStoreInstance(storeId) provider } @@ -202,7 +276,7 @@ object StateStore extends Logging { /** Unload a state store provider */ def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeId) + loadedProviders.remove(storeId).foreach(_.close()) } /** Whether a state store provider is loaded or not */ @@ -216,6 +290,7 @@ object StateStore extends Logging { /** Unload and stop all state store providers */ def stop(): Unit = loadedProviders.synchronized { + loadedProviders.keySet.foreach { key => unload(key) } loadedProviders.clear() _coordRef = null if (maintenanceTask != null) { http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index acfaa8e..bab297c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -20,16 +20,34 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. */ -private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { +class StateStoreConf(@transient private val sqlConf: SQLConf) + extends Serializable { def this() = this(new SQLConf) - val minDeltasForSnapshot = conf.stateStoreMinDeltasForSnapshot - - val minVersionsToRetain = conf.minBatchesToRetain + /** + * Minimum number of delta files in a chain after which HDFSBackedStateStore will + * consider generating a snapshot. + */ + val minDeltasForSnapshot: Int = sqlConf.stateStoreMinDeltasForSnapshot + + /** Minimum versions a State Store implementation should retain to allow rollbacks */ + val minVersionsToRetain: Int = sqlConf.minBatchesToRetain + + /** + * Optional fully qualified name of the subclass of [[StateStoreProvider]] + * managing state data. That is, the implementation of the State Store to use. + */ + val providerClass: Option[String] = sqlConf.stateStoreProviderClass + + /** + * Additional configurations related to state store. 
This will capture all configs in + * SQLConf that start with `spark.sql.streaming.stateStore.` */ + val confs: Map[String, String] = + sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) } -private[streaming] object StateStoreConf { +object StateStoreConf { val empty = new StateStoreConf() def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf) http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index e16dda8..b744c25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -35,9 +35,11 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], checkpointLocation: String, operatorId: Long, + storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, + indexOrdinal: Option[Int], sessionState: SessionState, @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) extends RDD[U](dataRDD) { @@ -45,21 +47,22 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( private val storeConf = new StateStoreConf(sessionState.conf) // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = dataRDD.context.broadcast( + private val hadoopConfBroadcast = dataRDD.context.broadcast( new SerializableConfiguration(sessionState.newHadoopConf())) override protected def getPartitions: Array[Partition] = dataRDD.partitions override def getPreferredLocations(partition: Partition): Seq[String] = { - val 
storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) storeCoordinator.flatMap(_.getLocation(storeId)).toSeq } override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) store = StateStore.get( - storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + storeId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) } http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 589042a..228fe86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -34,17 +34,21 @@ package object state { sqlContext: SQLContext, checkpointLocation: String, operatorId: Long, + storeName: String, storeVersion: Long, keySchema: StructType, - valueSchema: StructType)( + valueSchema: StructType, + indexOrdinal: Option[Int])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { mapPartitionsWithStateStore( checkpointLocation, operatorId, + storeName, storeVersion, keySchema, valueSchema, + indexOrdinal, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator))( storeUpdateFunction) @@ -54,9 
+58,11 @@ package object state { private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, + storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, + indexOrdinal: Option[Int], sessionState: SessionState, storeCoordinator: Option[StateStoreCoordinatorRef])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { @@ -69,14 +75,17 @@ package object state { }) cleanedF(store, iter) } + new StateStoreRDD( dataRDD, wrappedF, checkpointLocation, operatorId, + storeName, storeVersion, keySchema, valueSchema, + indexOrdinal, sessionState, storeCoordinator) } http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 8dbda29..3e57f3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,21 +17,22 @@ package org.apache.spark.sql.execution.streaming +import java.util.concurrent.TimeUnit._ + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import 
org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ -import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.{CompletionIterator, NextIterator} /** Used to identify the state store for a given operator. */ @@ -61,11 +62,24 @@ trait StateStoreReader extends StatefulOperator { } /** An operator that writes to a StateStore. */ -trait StateStoreWriter extends StatefulOperator { +trait StateStoreWriter extends StatefulOperator { self: SparkPlan => + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), - "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"), + "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to update rows"), + "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to remove rows"), + "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes") + ) + + /** Records the duration of running `body` for the next query progress update. */ + protected def timeTakenMs(body: => Unit): Long = { + val startTime = System.nanoTime() + val result = body + val endTime = System.nanoTime() + math.max(NANOSECONDS.toMillis(endTime - startTime), 0) + } } /** An operator that supports watermark. 
*/ @@ -108,6 +122,16 @@ trait WatermarkSupport extends UnaryExecNode { /** Predicate based on the child output that matches data older than the watermark. */ lazy val watermarkPredicateForData: Option[Predicate] = watermarkExpression.map(newPredicate(_, child.output)) + + protected def removeKeysOlderThanWatermark(store: StateStore): Unit = { + if (watermarkPredicateForKeys.nonEmpty) { + store.getRange(None, None).foreach { rowPair => + if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + store.remove(rowPair.key) + } + } + } + } } /** @@ -126,9 +150,11 @@ case class StateStoreRestoreExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, + storeName = "default", storeVersion = getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) @@ -136,7 +162,7 @@ case class StateStoreRestoreExec( val key = getKey(row) val savedState = store.get(key) numOutputRows += 1 - row +: savedState.toSeq + row +: Option(savedState).toSeq } } } @@ -165,54 +191,88 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, getStateId.operatorId, + storeName = "default", getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") 
outputMode match { // Update and output all rows in the StateStore. case Some(Complete) => - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 + allUpdatesTimeMs += timeTakenMs { + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numUpdatedStateRows += 1 + } + } + allRemovalsTimeMs += 0 + commitTimeMs += timeTakenMs { + store.commit() } - store.commit() numTotalStateRows += store.numKeys() - store.iterator().map { case (k, v) => + store.iterator().map { rowPair => numOutputRows += 1 - v.asInstanceOf[InternalRow] + rowPair.value } // Update and output only rows being evicted from the StateStore + // Assumption: watermark predicates must be non-empty if append mode is allowed case Some(Append) => - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 + allUpdatesTimeMs += timeTakenMs { + val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row)) + while (filteredIter.hasNext) { + val row = filteredIter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numUpdatedStateRows += 1 + } } - // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicateForKeys.get.eval _) - store.commit() + val removalStartTimeNs = System.nanoTime + val rangeIter = store.getRange(None, None) + + new NextIterator[InternalRow] { + override protected def getNext(): InternalRow = { + var removedValueRow: InternalRow = null + while(rangeIter.hasNext && removedValueRow == null) { + val rowPair = rangeIter.next() + if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + store.remove(rowPair.key) + removedValueRow = rowPair.value + } + } + if (removedValueRow == null) { + finished = true + null + } else { + removedValueRow 
+ } + } - numTotalStateRows += store.numKeys() - store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed => - numOutputRows += 1 - removed.value.asInstanceOf[InternalRow] + override protected def close(): Unit = { + allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + numTotalStateRows += store.numKeys() + } } // Update and output modified rows from the StateStore. case Some(Update) => + val updatesStartTimeNs = System.nanoTime + new Iterator[InternalRow] { // Filter late date using watermark if specified @@ -223,11 +283,11 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + // Remove old aggregates if watermark specified - if (watermarkPredicateForKeys.nonEmpty) { - store.remove(watermarkPredicateForKeys.get.eval _) - } - store.commit() + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } numTotalStateRows += store.numKeys() false } else { @@ -238,7 +298,7 @@ case class StateStoreSaveExec( override def next(): InternalRow = { val row = baseIterator.next().asInstanceOf[UnsafeRow] val key = getKey(row) - store.put(key.copy(), row.copy()) + store.put(key, row) numOutputRows += 1 numUpdatedStateRows += 1 row @@ -273,27 +333,34 @@ case class StreamingDeduplicateExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, getStateId.operatorId, + storeName = "default", getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = 
longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter(row => !predicate.eval(row)) case None => iter } + val updatesStartTimeNs = System.nanoTime + val result = baseIterator.filter { r => val row = r.asInstanceOf[UnsafeRow] val key = getKey(row) val value = store.get(key) - if (value.isEmpty) { - store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW) + if (value == null) { + store.put(key, StreamingDeduplicateExec.EMPTY_ROW) numUpdatedStateRows += 1 numOutputRows += 1 true @@ -304,8 +371,9 @@ case class StreamingDeduplicateExec( } CompletionIterator[InternalRow, Iterator[InternalRow]](result, { - watermarkPredicateForKeys.foreach(f => store.remove(f.eval _)) - store.commit() + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } numTotalStateRows += store.numKeys() }) } http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index bd197be..4a1a089 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -38,13 +38,13 @@ import org.apache.spark.util.{CompletionIterator, Utils} class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll 
{ + import StateStoreTestsHelper._ + private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) - private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + private val tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString private val keySchema = StructType(Seq(StructField("key", StringType, true))) private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) - import StateStoreSuite._ - after { StateStore.stop() } @@ -60,13 +60,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
@@ -84,7 +85,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn storeVersion: Int): RDD[(String, Int)] = { implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment) } // Generate RDDs and state store data @@ -110,7 +111,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = { val resIterator = iter.map { s => val key = stringToRow(s) - val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0) val newValue = oldValue + 1 store.put(key, intToRow(newValue)) (s, newValue) @@ -125,21 +126,24 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn iter: Iterator[String]): Iterator[(String, Option[Int])] = { iter.map { s => val key = stringToRow(s) - val value = store.get(key).map(rowToInt) + val value = Option(store.get(key)).map(rowToInt) (s, value) } } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) + sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, 
Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) + sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } } @@ -152,15 +156,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2") assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + increment) require(rdd.partitions.length === 2) assert( @@ -187,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate 
next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -208,7 +213,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => val key = stringToRow(s) - val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0) store.put(key, intToRow(oldValue + 1)) } store.commit() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org