Repository: spark
Updated Branches:
  refs/heads/master 4bb6a53eb -> fa757ee1d
http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index cc09b2d..af2b9f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -40,15 +40,15 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester {
+class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
+  with BeforeAndAfter with PrivateMethodTester {
   type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
 
   import StateStoreCoordinatorSuite._
-  import StateStoreSuite._
+  import StateStoreTestsHelper._
 
-  private val tempDir = Utils.createTempDir().toString
-  private val keySchema = StructType(Seq(StructField("key", StringType, true)))
-  private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+  val keySchema = StructType(Seq(StructField("key", StringType, true)))
+  val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
 
   before {
     StateStore.stop()
@@ -60,186 +60,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     require(!StateStore.isMaintenanceRunning)
   }
 
-  test("get, put, remove, commit, and all data iterator") {
-    val provider = newStoreProvider()
-
-    // Verify state before starting a new set of updates
-    assert(provider.latestIterator().isEmpty)
-
-    val store = provider.getStore(0)
-    assert(!store.hasCommitted)
-    intercept[IllegalStateException] {
-      store.iterator()
-    }
-    intercept[IllegalStateException] {
-      store.updates()
-    }
-
-    // Verify state after updating
-    put(store, "a", 1)
-    assert(store.numKeys() === 1)
-    intercept[IllegalStateException] {
-      store.iterator()
-    }
-    intercept[IllegalStateException] {
-      store.updates()
-    }
-    assert(provider.latestIterator().isEmpty)
-
-    // Make updates, commit and then verify state
-    put(store, "b", 2)
-    put(store, "aa", 3)
-    assert(store.numKeys() === 3)
-    remove(store, _.startsWith("a"))
-    assert(store.numKeys() === 1)
-    assert(store.commit() === 1)
-
-    assert(store.hasCommitted)
-    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
-    assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2))
-    assert(fileExists(provider, version = 1, isSnapshot = false))
-
-    assert(getDataFromFiles(provider) === Set("b" -> 2))
-
-    // Trying to get newer versions should fail
-    intercept[Exception] {
-      provider.getStore(2)
-    }
-    intercept[Exception] {
-      getDataFromFiles(provider, 2)
-    }
-
-    // New updates to the reloaded store with new version, and does not change old version
-    val reloadedProvider = new HDFSBackedStateStoreProvider(
-      store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
-    val reloadedStore = reloadedProvider.getStore(1)
-    assert(reloadedStore.numKeys() === 1)
-    put(reloadedStore, "c", 4)
-    assert(reloadedStore.numKeys() === 2)
-    assert(reloadedStore.commit() === 2)
-    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
-    assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
-    assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2))
-    assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4))
-  }
-
-  test("filter and concurrent updates") {
-    val provider = newStoreProvider()
-
-    // Verify state before starting a new set of updates
-    assert(provider.latestIterator.isEmpty)
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    put(store, "b", 2)
-
-    // Updates should work while iterating of filtered entries
-    val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" }
-    filtered.foreach { case (keyRow, valueRow) =>
-      store.put(keyRow, intToRow(rowToInt(valueRow) + 1))
-    }
-    assert(get(store, "a") === Some(2))
-
-    // Removes should work while iterating of filtered entries
-    val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" }
-    filtered2.foreach { case (keyRow, _) =>
-      store.remove(keyRow)
-    }
-    assert(get(store, "b") === None)
-  }
-
-  test("updates iterator with all combos of updates and removes") {
-    val provider = newStoreProvider()
-    var currentVersion: Int = 0
-
-    def withStore(body: StateStore => Unit): Unit = {
-      val store = provider.getStore(currentVersion)
-      body(store)
-      currentVersion += 1
-    }
-
-    // New data should be seen in updates as value added, even if they had multiple updates
-    withStore { store =>
-      put(store, "a", 1)
-      put(store, "aa", 1)
-      put(store, "aa", 2)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
-    }
-
-    // Multiple updates to same key should be collapsed in the updates as a single value update
-    // Keys that have not been updated should not appear in the updates
-    withStore { store =>
-      put(store, "a", 4)
-      put(store, "a", 6)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
-    }
-
-    // Keys added, updated and finally removed before commit should not appear in updates
-    withStore { store =>
-      put(store, "b", 4)     // Added, finally removed
-      put(store, "bb", 5)    // Added, updated, finally removed
-      put(store, "bb", 6)
-      remove(store, _.startsWith("b"))
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set.empty)
-      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
-    }
-
-    // Removed data should be seen in updates as a key removed
-    // Removed, but re-added data should be seen in updates as a value update
-    withStore { store =>
-      remove(store, _.startsWith("a"))
-      put(store, "a", 10)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 10))
-    }
-  }
-
-  test("cancel") {
-    val provider = newStoreProvider()
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    store.commit()
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
-
-    // cancelUpdates should not change the data in the files
-    val store1 = provider.getStore(1)
-    put(store1, "b", 1)
-    store1.abort()
-    assert(getDataFromFiles(provider) === Set("a" -> 1))
-  }
-
-  test("getStore with unexpected versions") {
-    val provider = newStoreProvider()
-
-    intercept[IllegalArgumentException] {
-      provider.getStore(-1)
-    }
-
-    // Prepare some data in the store
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    assert(store.commit() === 1)
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
-
-    intercept[IllegalStateException] {
-      provider.getStore(2)
-    }
-
-    // Update store version with some data
-    val store1 = provider.getStore(1)
-    put(store1, "b", 1)
-    assert(store1.commit() === 2)
-    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
-    assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
-  }
-
   test("snapshotting") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
 
     var currentVersion = 0
 
     def updateVersionTo(targetVersion: Int): Unit = {
@@ -253,9 +75,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     }
 
     updateVersionTo(2)
-    require(getDataFromFiles(provider) === Set("a" -> 2))
+    require(getData(provider) === Set("a" -> 2))
     provider.doMaintenance()               // should not generate snapshot files
-    assert(getDataFromFiles(provider) === Set("a" -> 2))
+    assert(getData(provider) === Set("a" -> 2))
 
     for (i <- 1 to currentVersion) {
       assert(fileExists(provider, i, isSnapshot = false))  // all delta files present
@@ -264,22 +86,22 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // After version 6, snapshotting should generate one snapshot file
     updateVersionTo(6)
-    require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly")
+    require(getData(provider) === Set("a" -> 6), "store not updated correctly")
     provider.doMaintenance()       // should generate snapshot files
 
     val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
     assert(snapshotVersion.nonEmpty, "snapshot file not generated")
     deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
     assert(
-      getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+      getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
       "snapshotting messed up the data of the snapshotted version")
     assert(
-      getDataFromFiles(provider) === Set("a" -> 6),
+      getData(provider) === Set("a" -> 6),
       "snapshotting messed up the data of the final version")
 
     // After version 20, snapshotting should generate newer snapshot files
     updateVersionTo(20)
-    require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly")
+    require(getData(provider) === Set("a" -> 20), "store not updated correctly")
     provider.doMaintenance()       // do snapshot
 
     val latestSnapshotVersion = (0 to 20).filter(version =>
@@ -288,11 +110,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
     deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
-    assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data")
+    assert(getData(provider) === Set("a" -> 20), "snapshotting messed up the data")
   }
 
   test("cleaning") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
 
     for (i <- 1 to 20) {
       val store = provider.getStore(i - 1)
@@ -307,8 +129,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted
 
     // last couple of versions should be retrievable
-    assert(getDataFromFiles(provider, 20) === Set("a" -> 20))
-    assert(getDataFromFiles(provider, 19) === Set("a" -> 19))
+    assert(getData(provider, 20) === Set("a" -> 20))
+    assert(getData(provider, 19) === Set("a" -> 19))
   }
 
   test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") {
@@ -316,7 +138,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName)
     conf.set("fs.defaultFS", "fake:///")
 
-    val provider = newStoreProvider(hadoopConf = conf)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf)
     provider.getStore(0).commit()
     provider.getStore(0).commit()
 
@@ -327,7 +149,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   }
 
   test("corrupted file handling") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
     for (i <- 1 to 6) {
       val store = provider.getStore(i - 1)
       put(store, "a", i)
@@ -338,62 +160,75 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
       fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found"))
 
     // Corrupt snapshot file and verify that it throws error
-    assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion))
+    assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion))
     corruptFile(provider, snapshotVersion, isSnapshot = true)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion)
+      getData(provider, snapshotVersion)
    }
 
     // Corrupt delta file and verify that it throws error
-    assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
+    assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
     corruptFile(provider, snapshotVersion - 1, isSnapshot = false)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion - 1)
+      getData(provider, snapshotVersion - 1)
     }
 
     // Delete delta file and verify that it throws error
     deleteFilesEarlierThanVersion(provider, snapshotVersion)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion - 1)
+      getData(provider, snapshotVersion - 1)
     }
   }
 
   test("StateStore.get") {
     quietly {
-      val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+      val dir = newDir()
      val storeId = StateStoreId(dir, 0, 0)
       val storeConf = StateStoreConf.empty
       val hadoopConf = new Configuration()
 
-      // Verify that trying to get incorrect versions throw errors
       intercept[IllegalArgumentException] {
-        StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf)
+        StateStore.get(
+          storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf)
       }
      assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
 
       intercept[IllegalStateException] {
-        StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+        StateStore.get(
+          storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
       }
 
-      // Increase version of the store
-      val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      // Increase version of the store and try to get again
+      val store0 = StateStore.get(
+        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
       assert(store0.version === 0)
       put(store0, "a", 1)
       store0.commit()
 
-      assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
-      assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0)
+      val store1 = StateStore.get(
+        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+      assert(StateStore.isLoaded(storeId))
+      assert(store1.version === 1)
+      assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+
+      // Verify that you can also load older version
+      val store0reloaded = StateStore.get(
+        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
+      assert(store0reloaded.version === 0)
+      assert(rowsToSet(store0reloaded.iterator()) === Set.empty)
 
       // Verify that you can remove the store and still reload and use it
       StateStore.unload(storeId)
       assert(!StateStore.isLoaded(storeId))
-      val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      val store1reloaded = StateStore.get(
+        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
-      put(store1, "a", 2)
-      assert(store1.commit() === 2)
-      assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
+      assert(store1reloaded.version === 1)
+      put(store1reloaded, "a", 2)
+      assert(store1reloaded.commit() === 2)
+      assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2))
     }
   }
 
@@ -407,21 +242,20 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
       // fails to talk to the StateStoreCoordinator and unloads all the StateStores
       .set("spark.rpc.numRetries", "1")
     val opId = 0
-    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val dir = newDir()
     val storeId = StateStoreId(dir, opId, 0)
     val sqlConf = new SQLConf()
     sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
     val storeConf = StateStoreConf(sqlConf)
     val hadoopConf = new Configuration()
-    val provider = new HDFSBackedStateStoreProvider(
-      storeId, keySchema, valueSchema, storeConf, hadoopConf)
+    val provider = newStoreProvider(storeId)
     var latestStoreVersion = 0
 
     def generateStoreVersions() {
       for (i <- 1 to 20) {
-        val store = StateStore.get(
-          storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+        val store = StateStore.get(storeId, keySchema, valueSchema, None,
+          latestStoreVersion, storeConf, hadoopConf)
         put(store, "a", i)
         store.commit()
         latestStoreVersion += 1
@@ -469,7 +303,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
         }
 
         // Reload the store and verify
-        StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+        StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+          latestStoreVersion, storeConf, hadoopConf)
         assert(StateStore.isLoaded(storeId))
 
         // If some other executor loads the store, then this instance should be unloaded
@@ -479,7 +314,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
         }
 
         // Reload the store and verify
-        StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+        StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+          latestStoreVersion, storeConf, hadoopConf)
         assert(StateStore.isLoaded(storeId))
       }
     }
@@ -495,10 +331,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
   test("SPARK-18342: commit fails when rename fails") {
     import RenameReturnsFalseFileSystem._
-    val dir = scheme + "://" + Utils.createDirectory(tempDir, Random.nextString(5)).toURI.getPath
+    val dir = scheme + "://" + newDir()
     val conf = new Configuration()
     conf.set(s"fs.$scheme.impl", classOf[RenameReturnsFalseFileSystem].getName)
-    val provider = newStoreProvider(dir = dir, hadoopConf = conf)
+    val provider = newStoreProvider(
+      opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf)
     val store = provider.getStore(0)
     put(store, "a", 0)
     val e = intercept[IllegalStateException](store.commit())
@@ -506,7 +343,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   }
 
   test("SPARK-18416: do not create temp delta file until the store is updated") {
-    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val dir = newDir()
     val storeId = StateStoreId(dir, 0, 0)
     val storeConf = StateStoreConf.empty
     val hadoopConf = new Configuration()
@@ -533,7 +370,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Getting the store should not create temp file
     val store0 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 0, storeConf, hadoopConf)
     }
 
     // Put should create a temp file
@@ -548,7 +386,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Remove should create a temp file
     val store1 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 1, storeConf, hadoopConf)
     }
     remove(store1, _ == "a")
     assert(numTempFiles === 1)
@@ -561,31 +400,55 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Commit without any updates should create a delta file
     val store2 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 2, storeConf, hadoopConf)
     }
     store2.commit()
     assert(numTempFiles === 0)
     assert(numDeltaFiles === 3)
   }
 
-  def getDataFromFiles(
-      provider: HDFSBackedStateStoreProvider,
+  override def newStoreProvider(): HDFSBackedStateStoreProvider = {
+    newStoreProvider(opId = Random.nextInt(), partition = 0)
+  }
+
+  override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = {
+    newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation)
+  }
+
+  override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = {
+    getData(storeProvider)
+  }
+
+  override def getData(
+      provider: HDFSBackedStateStoreProvider,
       version: Int = -1): Set[(String, Int)] = {
-    val reloadedProvider = new HDFSBackedStateStoreProvider(
-      provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+    val reloadedProvider = newStoreProvider(provider.id)
     if (version < 0) {
       reloadedProvider.latestIterator().map(rowsToStringInt).toSet
     } else {
-      reloadedProvider.iterator(version).map(rowsToStringInt).toSet
+      reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet
     }
   }
 
-  def assertMap(
-      testMapOption: Option[MapType],
-      expectedMap: Map[String, Int]): Unit = {
-    assert(testMapOption.nonEmpty, "no map present")
-    val convertedMap = testMapOption.get.map(rowsToStringInt)
-    assert(convertedMap === expectedMap)
+  def newStoreProvider(
+      opId: Long,
+      partition: Int,
+      dir: String = newDir(),
+      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = {
+    val sqlConf = new SQLConf()
+    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
+    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
+    val provider = new HDFSBackedStateStoreProvider()
+    provider.init(
+      StateStoreId(dir, opId, partition),
+      keySchema,
+      valueSchema,
+      indexOrdinal = None,
+      new StateStoreConf(sqlConf),
+      hadoopConf)
+    provider
   }
 
   def fileExists(
@@ -622,56 +485,150 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     filePath.delete()
     filePath.createNewFile()
   }
+}
 
-  def storeLoaded(storeId: StateStoreId): Boolean = {
-    val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores)
-    val loadedStores = StateStore invokePrivate method()
-    loadedStores.contains(storeId)
-  }
+abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
+  extends SparkFunSuite {
+  import StateStoreTestsHelper._
 
-  def unloadStore(storeId: StateStoreId): Boolean = {
-    val method = PrivateMethod('remove)
-    StateStore invokePrivate method(storeId)
-  }
+  test("get, put, remove, commit, and all data iterator") {
+    val provider = newStoreProvider()
 
-  def newStoreProvider(
-      opId: Long = Random.nextLong,
-      partition: Int = 0,
-      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
-      dir: String = Utils.createDirectory(tempDir, Random.nextString(5)).toString,
-      hadoopConf: Configuration = new Configuration()
-    ): HDFSBackedStateStoreProvider = {
-    val sqlConf = new SQLConf()
-    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
-    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
-    new HDFSBackedStateStoreProvider(
-      StateStoreId(dir, opId, partition),
-      keySchema,
-      valueSchema,
-      new StateStoreConf(sqlConf),
-      hadoopConf)
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+
+    val store = provider.getStore(0)
+    assert(!store.hasCommitted)
+    assert(get(store, "a") === None)
+    assert(store.iterator().isEmpty)
+    assert(store.numKeys() === 0)
+
+    // Verify state after updating
+    put(store, "a", 1)
+    assert(get(store, "a") === Some(1))
+    assert(store.numKeys() === 1)
+
+    assert(store.iterator().nonEmpty)
+    assert(getLatestData(provider).isEmpty)
+
+    // Make updates, commit and then verify state
+    put(store, "b", 2)
+    put(store, "aa", 3)
+    assert(store.numKeys() === 3)
+    remove(store, _.startsWith("a"))
+    assert(store.numKeys() === 1)
+    assert(store.commit() === 1)
+
+    assert(store.hasCommitted)
+    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
+    assert(getLatestData(provider) === Set("b" -> 2))
+
+    // Trying to get newer versions should fail
+    intercept[Exception] {
+      provider.getStore(2)
+    }
+    intercept[Exception] {
+      getData(provider, 2)
+    }
+
+    // New updates to the reloaded store with new version, and does not change old version
+    val reloadedProvider = newStoreProvider(store.id)
+    val reloadedStore = reloadedProvider.getStore(1)
+    assert(reloadedStore.numKeys() === 1)
+    put(reloadedStore, "c", 4)
+    assert(reloadedStore.numKeys() === 2)
+    assert(reloadedStore.commit() === 2)
+    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
+    assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
+    assert(getData(provider, version = 1) === Set("b" -> 2))
   }
 
-  def remove(store: StateStore, condition: String => Boolean): Unit = {
-    store.remove(row => condition(rowToString(row)))
+  test("removing while iterating") {
+    val provider = newStoreProvider()
+
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    put(store, "b", 2)
+
+    // Updates should work while iterating of filtered entries
+    val filtered = store.iterator.filter { tuple => rowToString(tuple.key) == "a" }
+    filtered.foreach { tuple =>
+      store.put(tuple.key, intToRow(rowToInt(tuple.value) + 1))
+    }
+    assert(get(store, "a") === Some(2))
+
+    // Removes should work while iterating of filtered entries
+    val filtered2 = store.iterator.filter { tuple => rowToString(tuple.key) == "b" }
+    filtered2.foreach { tuple => store.remove(tuple.key) }
+    assert(get(store, "b") === None)
   }
 
-  private def put(store: StateStore, key: String, value: Int): Unit = {
-    store.put(stringToRow(key), intToRow(value))
+  test("abort") {
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    store.commit()
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    // cancelUpdates should not change the data in the files
+    val store1 = provider.getStore(1)
+    put(store1, "b", 1)
+    store1.abort()
   }
 
-  private def get(store: StateStore, key: String): Option[Int] = {
-    store.get(stringToRow(key)).map(rowToInt)
+  test("getStore with invalid versions") {
+    val provider = newStoreProvider()
+
+    def checkInvalidVersion(version: Int): Unit = {
+      intercept[Exception] {
+        provider.getStore(version)
+      }
+    }
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(1)
+
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    assert(store.commit() === 1)
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    val store1_ = provider.getStore(1)
+    assert(rowsToSet(store1_.iterator()) === Set("a" -> 1))
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(2)
+
+    // Update store version with some data
+    val store1 = provider.getStore(1)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+    put(store1, "b", 1)
+    assert(store1.commit() === 2)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(3)
   }
-}
 
-private[state] object StateStoreSuite {
+  /** Return a new provider with a random id */
+  def newStoreProvider(): ProviderClass
+
+  /** Return a new provider with the given id */
+  def newStoreProvider(storeId: StateStoreId): ProviderClass
+
+  /** Get the latest data referred to by the given provider but not using this provider */
+  def getLatestData(storeProvider: ProviderClass): Set[(String, Int)]
+
+  /**
+   * Get a specific version of data referred to by the given provider but not using
+   * this provider
+   */
+  def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)]
+}
 
-  /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */
-  trait TestUpdate
-  case class Added(key: String, value: Int) extends TestUpdate
-  case class Updated(key: String, value: Int) extends TestUpdate
-  case class Removed(key: String) extends TestUpdate
+object StateStoreTestsHelper {
 
   val strProj = UnsafeProjection.create(Array[DataType](StringType))
   val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
@@ -692,26 +649,29 @@ private[state] object StateStoreSuite {
     row.getInt(0)
   }
 
-  def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = {
-    (rowToInt(row._1), rowToInt(row._2))
+  def rowsToStringInt(row: UnsafeRowPair): (String, Int) = {
+    (rowToString(row.key), rowToInt(row.value))
   }
 
+  def rowsToSet(iterator: Iterator[UnsafeRowPair]): Set[(String, Int)] = {
+    iterator.map(rowsToStringInt).toSet
+  }
 
-  def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = {
-    (rowToString(row._1), rowToInt(row._2))
+  def remove(store: StateStore, condition: String => Boolean): Unit = {
+    store.getRange(None, None).foreach { rowPair =>
+      if (condition(rowToString(rowPair.key))) store.remove(rowPair.key)
+    }
   }
 
-  def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = {
-    iterator.map(rowsToStringInt).toSet
+  def put(store: StateStore, key: String, value: Int): Unit = {
+    store.put(stringToRow(key), intToRow(value))
   }
 
-  def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = {
-    iterator.map {
-      case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value))
-      case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value))
-      case ValueRemoved(key, _) => Removed(rowToString(key))
-    }.toSet
+  def get(store: StateStore, key: String): Option[Int] = {
+    Option(store.get(stringToRow(key))).map(rowToInt)
   }
+
+  def newDir(): String = Utils.createTempDir().toString
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 6bb9408..0d9ca81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate}
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair}
 import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -508,22 +508,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     expectedState = Some(5),                 // state should change
     expectedTimeoutTimestamp = 5000)         // timestamp should change
 
-  test("StateStoreUpdater - rows are cloned before writing to StateStore") {
-    // function for running count
-    val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
-      state.update(state.getOption.getOrElse(0) + values.size)
-      Iterator.empty
-    }
-    val store = newStateStore()
-    val plan = newFlatMapGroupsWithStateExec(func)
-    val updater = new plan.StateStoreUpdater(store)
-    val data = Seq(1, 1, 2)
-    val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow))
-    returnIter.size // consume the iterator to force store updates
-    val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet
-    assert(storeData === Set((1, 2), (2, 1)))
-  }
-
   test("flatMapGroupsWithState - streaming") {
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count if state is defined, otherwise does not return anything
@@ -1016,11 +1000,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       callFunction()
       val updatedStateRow = store.get(key)
       assert(
-        updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
+        Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState,
         "final state not as expected")
-      if (updatedStateRow.nonEmpty) {
+      if (updatedStateRow != null) {
         assert(
-          updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
+          updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp,
           "final timeout timestamp not as expected")
       }
     }
@@ -1080,25 +1064,19 @@ object FlatMapGroupsWithStateSuite {
     import scala.collection.JavaConverters._
     private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
 
-    override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
-      map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) }
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
     }
 
-    override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = {
-      iterator.filter { case (k, v) => c(k, v) }
+    override def get(key: UnsafeRow): UnsafeRow = map.get(key)
+    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
+      map.put(key.copy(), newValue.copy())
     }
-
-    override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key))
-    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue)
     override def remove(key: UnsafeRow): Unit = { map.remove(key) }
-    override def remove(condition: (UnsafeRow) => Boolean): Unit = {
-      iterator.map(_._1).filter(condition).foreach(map.remove)
-    }
     override def commit(): Long = version + 1
     override def abort(): Unit = { }
     override def id: StateStoreId = null
     override def version: Long = 0
-    override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException }
     override def numKeys(): Long = map.size
     override def hasCommitted: Boolean = true
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 1fc0629..280f2dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import scala.util.control.ControlThrowable
 
 import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -31,6 +32,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
@@ -614,6 +616,30 @@ class StreamSuite extends StreamTest {
     assertDescContainsQueryNameAnd(batch = 2)
     query.stop()
   }
+
+  testQuietly("specify custom state store provider") {
+    val queryName = "memStream"
+    val providerClassName = classOf[TestStateStoreProvider].getCanonicalName
+    withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) {
+      val input = MemoryStream[Int]
+      val query = input
+        .toDS()
+        .groupBy()
+        .count()
+        .writeStream
+        .outputMode("complete")
+        .format("memory")
+        .queryName(queryName)
+        .start()
+      input.addData(1, 2, 3)
+      val e = intercept[Exception] {
+        query.awaitTermination()
+      }
+
+      assert(e.getMessage.contains(providerClassName))
+      assert(e.getMessage.contains("instantiated"))
+    }
+  }
 }
 
 abstract class FakeSource extends StreamSourceProvider {
@@ -719,3 +745,22 @@ object ThrowingInterruptedIOException {
    */
   @volatile var createSourceLatch: CountDownLatch = null
 }
+
+class TestStateStoreProvider extends StateStoreProvider {
+
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int],
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    throw new Exception("Successfully instantiated")
+  }
+
+  override def id: StateStoreId = null
+
+  override def close(): Unit = { }
+
+  override def getStore(version: Long): StateStore = null
+}
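The API shift this commit makes is easier to see outside of diff context: a StateStoreProvider now has a zero-arg constructor plus an init(...) call (which is what lets spark.sql.streaming.stateStore.providerClass load an arbitrary implementation reflectively, as the new StreamSuite test checks), get() returns a nullable UnsafeRow instead of an Option, and iterator() yields UnsafeRowPair instead of (UnsafeRow, UnsafeRow) tuples. Below is a minimal usage sketch assembled from the patch's own test helpers; the driver object name, the temp-dir setup, and the stringToRow/intToRow helpers are illustrative, while the provider and store calls mirror code that appears in the patch.

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, StateStoreConf, StateStoreId}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

// Hypothetical driver, not part of the patch.
object StateStoreApiSketch {
  val keySchema = StructType(Seq(StructField("key", StringType, true)))
  val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))

  // Same projection-based row helpers StateStoreTestsHelper uses; copy() is
  // needed because a projection reuses its output buffer across calls.
  val strProj = UnsafeProjection.create(Array[DataType](StringType))
  val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
  def stringToRow(s: String) = strProj.apply(InternalRow(UTF8String.fromString(s))).copy()
  def intToRow(i: Int) = intProj.apply(InternalRow(i)).copy()

  def main(args: Array[String]): Unit = {
    // New lifecycle: construct empty, then init(). The engine can instantiate
    // any class named in spark.sql.streaming.stateStore.providerClass the
    // same way, since no constructor arguments are required.
    val provider = new HDFSBackedStateStoreProvider()
    provider.init(
      StateStoreId(Utils.createTempDir().toString, 0, 0),
      keySchema,
      valueSchema,
      indexOrdinal = None, // parameter added by this commit
      new StateStoreConf(new SQLConf),
      new Configuration)

    val store = provider.getStore(0) // version 0 is always an empty store
    store.put(stringToRow("a"), intToRow(1))

    // get() now returns a nullable UnsafeRow, not Option[UnsafeRow].
    assert(Option(store.get(stringToRow("a"))).map(_.getInt(0)) == Some(1))

    // iterator() now yields UnsafeRowPair instead of (UnsafeRow, UnsafeRow).
    store.iterator().foreach { pair =>
      println(s"${pair.key.getUTF8String(0)} -> ${pair.value.getInt(0)}")
    }

    assert(store.commit() == 1) // committing the updates produces version 1
    provider.close()
  }
}

Run in a single JVM this behaves like the suite's newStoreProvider helper; the conf-driven path presumably differs only in that the class name comes from SQLConf and the same construct-then-init sequence is performed by the engine (the StreamSuite test above only verifies that init is reached, via the "Successfully instantiated" exception).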