Repository: spark
Updated Branches:
  refs/heads/master 4bb6a53eb -> fa757ee1d


http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index cc09b2d..af2b9f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -40,15 +40,15 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester {
+class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
+  with BeforeAndAfter with PrivateMethodTester {
   type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
 
   import StateStoreCoordinatorSuite._
-  import StateStoreSuite._
+  import StateStoreTestsHelper._
 
-  private val tempDir = Utils.createTempDir().toString
-  private val keySchema = StructType(Seq(StructField("key", StringType, true)))
-  private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+  val keySchema = StructType(Seq(StructField("key", StringType, true)))
+  val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
 
   before {
     StateStore.stop()
@@ -60,186 +60,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     require(!StateStore.isMaintenanceRunning)
   }
 
-  test("get, put, remove, commit, and all data iterator") {
-    val provider = newStoreProvider()
-
-    // Verify state before starting a new set of updates
-    assert(provider.latestIterator().isEmpty)
-
-    val store = provider.getStore(0)
-    assert(!store.hasCommitted)
-    intercept[IllegalStateException] {
-      store.iterator()
-    }
-    intercept[IllegalStateException] {
-      store.updates()
-    }
-
-    // Verify state after updating
-    put(store, "a", 1)
-    assert(store.numKeys() === 1)
-    intercept[IllegalStateException] {
-      store.iterator()
-    }
-    intercept[IllegalStateException] {
-      store.updates()
-    }
-    assert(provider.latestIterator().isEmpty)
-
-    // Make updates, commit and then verify state
-    put(store, "b", 2)
-    put(store, "aa", 3)
-    assert(store.numKeys() === 3)
-    remove(store, _.startsWith("a"))
-    assert(store.numKeys() === 1)
-    assert(store.commit() === 1)
-
-    assert(store.hasCommitted)
-    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
-    assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2))
-    assert(fileExists(provider, version = 1, isSnapshot = false))
-
-    assert(getDataFromFiles(provider) === Set("b" -> 2))
-
-    // Trying to get newer versions should fail
-    intercept[Exception] {
-      provider.getStore(2)
-    }
-    intercept[Exception] {
-      getDataFromFiles(provider, 2)
-    }
-
-    // New updates to the reloaded store with new version, and does not change old version
-    val reloadedProvider = new HDFSBackedStateStoreProvider(
-      store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
-    val reloadedStore = reloadedProvider.getStore(1)
-    assert(reloadedStore.numKeys() === 1)
-    put(reloadedStore, "c", 4)
-    assert(reloadedStore.numKeys() === 2)
-    assert(reloadedStore.commit() === 2)
-    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
-    assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
-    assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2))
-    assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4))
-  }
-
-  test("filter and concurrent updates") {
-    val provider = newStoreProvider()
-
-    // Verify state before starting a new set of updates
-    assert(provider.latestIterator.isEmpty)
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    put(store, "b", 2)
-
-    // Updates should work while iterating of filtered entries
-    val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" }
-    filtered.foreach { case (keyRow, valueRow) =>
-      store.put(keyRow, intToRow(rowToInt(valueRow) + 1))
-    }
-    assert(get(store, "a") === Some(2))
-
-    // Removes should work while iterating of filtered entries
-    val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" }
-    filtered2.foreach { case (keyRow, _) =>
-      store.remove(keyRow)
-    }
-    assert(get(store, "b") === None)
-  }
-
-  test("updates iterator with all combos of updates and removes") {
-    val provider = newStoreProvider()
-    var currentVersion: Int = 0
-
-    def withStore(body: StateStore => Unit): Unit = {
-      val store = provider.getStore(currentVersion)
-      body(store)
-      currentVersion += 1
-    }
-
-    // New data should be seen in updates as value added, even if they had multiple updates
-    withStore { store =>
-      put(store, "a", 1)
-      put(store, "aa", 1)
-      put(store, "aa", 2)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
-    }
-
-    // Multiple updates to same key should be collapsed in the updates as a single value update
-    // Keys that have not been updated should not appear in the updates
-    withStore { store =>
-      put(store, "a", 4)
-      put(store, "a", 6)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
-    }
-
-    // Keys added, updated and finally removed before commit should not appear in updates
-    withStore { store =>
-      put(store, "b", 4)     // Added, finally removed
-      put(store, "bb", 5)    // Added, updated, finally removed
-      put(store, "bb", 6)
-      remove(store, _.startsWith("b"))
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set.empty)
-      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
-    }
-
-    // Removed data should be seen in updates as a key removed
-    // Removed, but re-added data should be seen in updates as a value update
-    withStore { store =>
-      remove(store, _.startsWith("a"))
-      put(store, "a", 10)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 10))
-    }
-  }
-
-  test("cancel") {
-    val provider = newStoreProvider()
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    store.commit()
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
-
-    // cancelUpdates should not change the data in the files
-    val store1 = provider.getStore(1)
-    put(store1, "b", 1)
-    store1.abort()
-    assert(getDataFromFiles(provider) === Set("a" -> 1))
-  }
-
-  test("getStore with unexpected versions") {
-    val provider = newStoreProvider()
-
-    intercept[IllegalArgumentException] {
-      provider.getStore(-1)
-    }
-
-    // Prepare some data in the store
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    assert(store.commit() === 1)
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
-
-    intercept[IllegalStateException] {
-      provider.getStore(2)
-    }
-
-    // Update store version with some data
-    val store1 = provider.getStore(1)
-    put(store1, "b", 1)
-    assert(store1.commit() === 2)
-    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
-    assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
-  }
-
   test("snapshotting") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
 
     var currentVersion = 0
     def updateVersionTo(targetVersion: Int): Unit = {
@@ -253,9 +75,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     }
 
     updateVersionTo(2)
-    require(getDataFromFiles(provider) === Set("a" -> 2))
+    require(getData(provider) === Set("a" -> 2))
     provider.doMaintenance()               // should not generate snapshot files
-    assert(getDataFromFiles(provider) === Set("a" -> 2))
+    assert(getData(provider) === Set("a" -> 2))
 
     for (i <- 1 to currentVersion) {
      assert(fileExists(provider, i, isSnapshot = false))  // all delta files present
@@ -264,22 +86,22 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // After version 6, snapshotting should generate one snapshot file
     updateVersionTo(6)
-    require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly")
+    require(getData(provider) === Set("a" -> 6), "store not updated correctly")
     provider.doMaintenance()       // should generate snapshot files
 
    val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
     assert(snapshotVersion.nonEmpty, "snapshot file not generated")
     deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
     assert(
-      getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+      getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
       "snapshotting messed up the data of the snapshotted version")
     assert(
-      getDataFromFiles(provider) === Set("a" -> 6),
+      getData(provider) === Set("a" -> 6),
       "snapshotting messed up the data of the final version")
 
     // After version 20, snapshotting should generate newer snapshot files
     updateVersionTo(20)
-    require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly")
+    require(getData(provider) === Set("a" -> 20), "store not updated correctly")
     provider.doMaintenance()       // do snapshot
 
     val latestSnapshotVersion = (0 to 20).filter(version =>
@@ -288,11 +110,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
    assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
 
     deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
-    assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data")
+    assert(getData(provider) === Set("a" -> 20), "snapshotting messed up the data")
   }
 
   test("cleaning") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
 
     for (i <- 1 to 20) {
       val store = provider.getStore(i - 1)
@@ -307,8 +129,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
    assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted
 
     // last couple of versions should be retrievable
-    assert(getDataFromFiles(provider, 20) === Set("a" -> 20))
-    assert(getDataFromFiles(provider, 19) === Set("a" -> 19))
+    assert(getData(provider, 20) === Set("a" -> 20))
+    assert(getData(provider, 19) === Set("a" -> 19))
   }
 
  test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") {
@@ -316,7 +138,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName)
     conf.set("fs.defaultFS", "fake:///")
 
-    val provider = newStoreProvider(hadoopConf = conf)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf)
     provider.getStore(0).commit()
     provider.getStore(0).commit()
 
@@ -327,7 +149,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   }
 
   test("corrupted file handling") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
     for (i <- 1 to 6) {
       val store = provider.getStore(i - 1)
       put(store, "a", i)
@@ -338,62 +160,75 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
      fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found"))
 
     // Corrupt snapshot file and verify that it throws error
-    assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion))
+    assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion))
     corruptFile(provider, snapshotVersion, isSnapshot = true)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion)
+      getData(provider, snapshotVersion)
     }
 
     // Corrupt delta file and verify that it throws error
-    assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
+    assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
     corruptFile(provider, snapshotVersion - 1, isSnapshot = false)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion - 1)
+      getData(provider, snapshotVersion - 1)
     }
 
     // Delete delta file and verify that it throws error
     deleteFilesEarlierThanVersion(provider, snapshotVersion)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion - 1)
+      getData(provider, snapshotVersion - 1)
     }
   }
 
   test("StateStore.get") {
     quietly {
-      val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+      val dir = newDir()
       val storeId = StateStoreId(dir, 0, 0)
       val storeConf = StateStoreConf.empty
       val hadoopConf = new Configuration()
 
-
       // Verify that trying to get incorrect versions throw errors
       intercept[IllegalArgumentException] {
-        StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf)
+        StateStore.get(
+          storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf)
       }
      assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
 
       intercept[IllegalStateException] {
-        StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+        StateStore.get(
+          storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
       }
 
-      // Increase version of the store
-      val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      // Increase version of the store and try to get again
+      val store0 = StateStore.get(
+        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
       assert(store0.version === 0)
       put(store0, "a", 1)
       store0.commit()
 
-      assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
-      assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0)
+      val store1 = StateStore.get(
+        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+      assert(StateStore.isLoaded(storeId))
+      assert(store1.version === 1)
+      assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+
+      // Verify that you can also load older version
+      val store0reloaded = StateStore.get(
+        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
+      assert(store0reloaded.version === 0)
+      assert(rowsToSet(store0reloaded.iterator()) === Set.empty)
 
       // Verify that you can remove the store and still reload and use it
       StateStore.unload(storeId)
       assert(!StateStore.isLoaded(storeId))
 
-      val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      val store1reloaded = StateStore.get(
+        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
-      put(store1, "a", 2)
-      assert(store1.commit() === 2)
-      assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
+      assert(store1reloaded.version === 1)
+      put(store1reloaded, "a", 2)
+      assert(store1reloaded.commit() === 2)
+      assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2))
     }
   }
 
@@ -407,21 +242,20 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
      // fails to talk to the StateStoreCoordinator and unloads all the StateStores
       .set("spark.rpc.numRetries", "1")
     val opId = 0
-    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val dir = newDir()
     val storeId = StateStoreId(dir, opId, 0)
     val sqlConf = new SQLConf()
     sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
     val storeConf = StateStoreConf(sqlConf)
     val hadoopConf = new Configuration()
-    val provider = new HDFSBackedStateStoreProvider(
-      storeId, keySchema, valueSchema, storeConf, hadoopConf)
+    val provider = newStoreProvider(storeId)
 
     var latestStoreVersion = 0
 
     def generateStoreVersions() {
       for (i <- 1 to 20) {
-        val store = StateStore.get(
-          storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+        val store = StateStore.get(storeId, keySchema, valueSchema, None,
+          latestStoreVersion, storeConf, hadoopConf)
         put(store, "a", i)
         store.commit()
         latestStoreVersion += 1
@@ -469,7 +303,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeId))
 
          // If some other executor loads the store, then this instance should be unloaded
@@ -479,7 +314,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeId))
         }
       }
@@ -495,10 +331,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
   test("SPARK-18342: commit fails when rename fails") {
     import RenameReturnsFalseFileSystem._
-    val dir = scheme + "://" + Utils.createDirectory(tempDir, Random.nextString(5)).toURI.getPath
+    val dir = scheme + "://" + newDir()
     val conf = new Configuration()
     conf.set(s"fs.$scheme.impl", classOf[RenameReturnsFalseFileSystem].getName)
-    val provider = newStoreProvider(dir = dir, hadoopConf = conf)
+    val provider = newStoreProvider(
+      opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf)
     val store = provider.getStore(0)
     put(store, "a", 0)
     val e = intercept[IllegalStateException](store.commit())
@@ -506,7 +343,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   }
 
  test("SPARK-18416: do not create temp delta file until the store is updated") {
-    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val dir = newDir()
     val storeId = StateStoreId(dir, 0, 0)
     val storeConf = StateStoreConf.empty
     val hadoopConf = new Configuration()
@@ -533,7 +370,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Getting the store should not create temp file
     val store0 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 0, storeConf, hadoopConf)
     }
 
     // Put should create a temp file
@@ -548,7 +386,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Remove should create a temp file
     val store1 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 1, storeConf, hadoopConf)
     }
     remove(store1, _ == "a")
     assert(numTempFiles === 1)
@@ -561,31 +400,55 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Commit without any updates should create a delta file
     val store2 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 2, storeConf, hadoopConf)
     }
     store2.commit()
     assert(numTempFiles === 0)
     assert(numDeltaFiles === 3)
   }
 
-  def getDataFromFiles(
-      provider: HDFSBackedStateStoreProvider,
+  override def newStoreProvider(): HDFSBackedStateStoreProvider = {
+    newStoreProvider(opId = Random.nextInt(), partition = 0)
+  }
+
+  override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = {
+    newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation)
+  }
+
+  override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = {
+    getData(storeProvider)
+  }
+
+  override def getData(
+    provider: HDFSBackedStateStoreProvider,
     version: Int = -1): Set[(String, Int)] = {
-    val reloadedProvider = new HDFSBackedStateStoreProvider(
-      provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+    val reloadedProvider = newStoreProvider(provider.id)
     if (version < 0) {
       reloadedProvider.latestIterator().map(rowsToStringInt).toSet
     } else {
-      reloadedProvider.iterator(version).map(rowsToStringInt).toSet
+      reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet
     }
   }
 
-  def assertMap(
-      testMapOption: Option[MapType],
-      expectedMap: Map[String, Int]): Unit = {
-    assert(testMapOption.nonEmpty, "no map present")
-    val convertedMap = testMapOption.get.map(rowsToStringInt)
-    assert(convertedMap === expectedMap)
+  def newStoreProvider(
+      opId: Long,
+      partition: Int,
+      dir: String = newDir(),
+      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = {
+    val sqlConf = new SQLConf()
+    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
+    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
+    val provider = new HDFSBackedStateStoreProvider()
+    provider.init(
+      StateStoreId(dir, opId, partition),
+      keySchema,
+      valueSchema,
+      indexOrdinal = None,
+      new StateStoreConf(sqlConf),
+      hadoopConf)
+    provider
   }
 
   def fileExists(
@@ -622,56 +485,150 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     filePath.delete()
     filePath.createNewFile()
   }
+}
 
-  def storeLoaded(storeId: StateStoreId): Boolean = {
-    val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores)
-    val loadedStores = StateStore invokePrivate method()
-    loadedStores.contains(storeId)
-  }
+abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
+  extends SparkFunSuite {
+  import StateStoreTestsHelper._
 
-  def unloadStore(storeId: StateStoreId): Boolean = {
-    val method = PrivateMethod('remove)
-    StateStore invokePrivate method(storeId)
-  }
+  test("get, put, remove, commit, and all data iterator") {
+    val provider = newStoreProvider()
 
-  def newStoreProvider(
-      opId: Long = Random.nextLong,
-      partition: Int = 0,
-      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
-      dir: String = Utils.createDirectory(tempDir, Random.nextString(5)).toString,
-      hadoopConf: Configuration = new Configuration()
-    ): HDFSBackedStateStoreProvider = {
-    val sqlConf = new SQLConf()
-    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
-    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
-    new HDFSBackedStateStoreProvider(
-      StateStoreId(dir, opId, partition),
-      keySchema,
-      valueSchema,
-      new StateStoreConf(sqlConf),
-      hadoopConf)
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+
+    val store = provider.getStore(0)
+    assert(!store.hasCommitted)
+    assert(get(store, "a") === None)
+    assert(store.iterator().isEmpty)
+    assert(store.numKeys() === 0)
+
+    // Verify state after updating
+    put(store, "a", 1)
+    assert(get(store, "a") === Some(1))
+    assert(store.numKeys() === 1)
+
+    assert(store.iterator().nonEmpty)
+    assert(getLatestData(provider).isEmpty)
+
+    // Make updates, commit and then verify state
+    put(store, "b", 2)
+    put(store, "aa", 3)
+    assert(store.numKeys() === 3)
+    remove(store, _.startsWith("a"))
+    assert(store.numKeys() === 1)
+    assert(store.commit() === 1)
+
+    assert(store.hasCommitted)
+    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
+    assert(getLatestData(provider) === Set("b" -> 2))
+
+    // Trying to get newer versions should fail
+    intercept[Exception] {
+      provider.getStore(2)
+    }
+    intercept[Exception] {
+      getData(provider, 2)
+    }
+
+    // New updates to the reloaded store with new version, and does not change old version
+    val reloadedProvider = newStoreProvider(store.id)
+    val reloadedStore = reloadedProvider.getStore(1)
+    assert(reloadedStore.numKeys() === 1)
+    put(reloadedStore, "c", 4)
+    assert(reloadedStore.numKeys() === 2)
+    assert(reloadedStore.commit() === 2)
+    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
+    assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
+    assert(getData(provider, version = 1) === Set("b" -> 2))
   }
 
-  def remove(store: StateStore, condition: String => Boolean): Unit = {
-    store.remove(row => condition(rowToString(row)))
+  test("removing while iterating") {
+    val provider = newStoreProvider()
+
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    put(store, "b", 2)
+
+    // Updates should work while iterating of filtered entries
+    val filtered = store.iterator.filter { tuple => rowToString(tuple.key) == "a" }
+    filtered.foreach { tuple =>
+      store.put(tuple.key, intToRow(rowToInt(tuple.value) + 1))
+    }
+    assert(get(store, "a") === Some(2))
+
+    // Removes should work while iterating of filtered entries
+    val filtered2 = store.iterator.filter { tuple => rowToString(tuple.key) == "b" }
+    filtered2.foreach { tuple => store.remove(tuple.key) }
+    assert(get(store, "b") === None)
   }
 
-  private def put(store: StateStore, key: String, value: Int): Unit = {
-    store.put(stringToRow(key), intToRow(value))
+  test("abort") {
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    store.commit()
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    // cancelUpdates should not change the data in the files
+    val store1 = provider.getStore(1)
+    put(store1, "b", 1)
+    store1.abort()
   }
 
-  private def get(store: StateStore, key: String): Option[Int] = {
-    store.get(stringToRow(key)).map(rowToInt)
+  test("getStore with invalid versions") {
+    val provider = newStoreProvider()
+
+    def checkInvalidVersion(version: Int): Unit = {
+      intercept[Exception] {
+        provider.getStore(version)
+      }
+    }
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(1)
+
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    assert(store.commit() === 1)
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    val store1_ = provider.getStore(1)
+    assert(rowsToSet(store1_.iterator()) === Set("a" -> 1))
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(2)
+
+    // Update store version with some data
+    val store1 = provider.getStore(1)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+    put(store1, "b", 1)
+    assert(store1.commit() === 2)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(3)
   }
-}
 
-private[state] object StateStoreSuite {
+  /** Return a new provider with a random id */
+  def newStoreProvider(): ProviderClass
+
+  /** Return a new provider with the given id */
+  def newStoreProvider(storeId: StateStoreId): ProviderClass
+
+  /** Get the latest data referred to by the given provider but not using this provider */
+  def getLatestData(storeProvider: ProviderClass): Set[(String, Int)]
+
+  /**
+   * Get a specific version of data referred to by the given provider but not using
+   * this provider
+   */
+  def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)]
+}
 
-  /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */
-  trait TestUpdate
-  case class Added(key: String, value: Int) extends TestUpdate
-  case class Updated(key: String, value: Int) extends TestUpdate
-  case class Removed(key: String) extends TestUpdate
+object StateStoreTestsHelper {
 
   val strProj = UnsafeProjection.create(Array[DataType](StringType))
   val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
@@ -692,26 +649,29 @@ private[state] object StateStoreSuite {
     row.getInt(0)
   }
 
-  def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = {
-    (rowToInt(row._1), rowToInt(row._2))
+  def rowsToStringInt(row: UnsafeRowPair): (String, Int) = {
+    (rowToString(row.key), rowToInt(row.value))
   }
 
+  def rowsToSet(iterator: Iterator[UnsafeRowPair]): Set[(String, Int)] = {
+    iterator.map(rowsToStringInt).toSet
+  }
 
-  def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = {
-    (rowToString(row._1), rowToInt(row._2))
+  def remove(store: StateStore, condition: String => Boolean): Unit = {
+    store.getRange(None, None).foreach { rowPair =>
+      if (condition(rowToString(rowPair.key))) store.remove(rowPair.key)
+    }
   }
 
-  def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = {
-    iterator.map(rowsToStringInt).toSet
+  def put(store: StateStore, key: String, value: Int): Unit = {
+    store.put(stringToRow(key), intToRow(value))
   }
 
-  def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = {
-    iterator.map {
-      case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value))
-      case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value))
-      case ValueRemoved(key, _) => Removed(rowToString(key))
-    }.toSet
+  def get(store: StateStore, key: String): Option[Int] = {
+    Option(store.get(stringToRow(key))).map(rowToInt)
   }
+
+  def newDir(): String = Utils.createTempDir().toString
 }
 
 /**
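
The hunks above replace the old constructor-based HDFSBackedStateStoreProvider setup with a two-step lifecycle: construct the provider with a no-arg constructor, configure it through init(), then load a specific version with getStore(). A minimal sketch of that flow, mirroring the newStoreProvider helper in this suite; checkpointDir, keySchema and valueSchema are assumed to be supplied by the caller and are not part of the patch:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, StateStoreConf, StateStoreId}

    // Sketch only: checkpointDir, keySchema and valueSchema are defined elsewhere by the caller.
    val provider = new HDFSBackedStateStoreProvider()
    provider.init(
      StateStoreId(checkpointDir, 0, 0),  // checkpoint location, operator id, partition id
      keySchema,
      valueSchema,
      indexOrdinal = None,                // no secondary index
      new StateStoreConf(new SQLConf),
      new Configuration)

    val store = provider.getStore(0)      // load version 0 to start a new batch
    store.commit()                        // committing (even with no updates) produces version 1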

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 6bb9408..0d9ca81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate}
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair}
 import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -508,22 +508,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
    expectedState = Some(5),                                  // state should change
    expectedTimeoutTimestamp = 5000)                          // timestamp should change
 
-  test("StateStoreUpdater - rows are cloned before writing to StateStore") {
-    // function for running count
-    val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
-      state.update(state.getOption.getOrElse(0) + values.size)
-      Iterator.empty
-    }
-    val store = newStateStore()
-    val plan = newFlatMapGroupsWithStateExec(func)
-    val updater = new plan.StateStoreUpdater(store)
-    val data = Seq(1, 1, 2)
-    val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow))
-    returnIter.size // consume the iterator to force store updates
-    val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet
-    assert(storeData === Set((1, 2), (2, 1)))
-  }
-
   test("flatMapGroupsWithState - streaming") {
     // Function to maintain running count up to 2, and then remove the count
    // Returns the data and the count if state is defined, otherwise does not return anything
@@ -1016,11 +1000,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       callFunction()
       val updatedStateRow = store.get(key)
       assert(
-        updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
+        Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState,
         "final state not as expected")
-      if (updatedStateRow.nonEmpty) {
+      if (updatedStateRow != null) {
         assert(
-          updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
+          updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp,
           "final timeout timestamp not as expected")
       }
     }
@@ -1080,25 +1064,19 @@ object FlatMapGroupsWithStateSuite {
     import scala.collection.JavaConverters._
     private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
 
-    override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
-      map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) }
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
     }
 
-    override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = {
-      iterator.filter { case (k, v) => c(k, v) }
+    override def get(key: UnsafeRow): UnsafeRow = map.get(key)
+    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
+      map.put(key.copy(), newValue.copy())
     }
-
-    override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key))
-    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue)
     override def remove(key: UnsafeRow): Unit = { map.remove(key) }
-    override def remove(condition: (UnsafeRow) => Boolean): Unit = {
-      iterator.map(_._1).filter(condition).foreach(map.remove)
-    }
     override def commit(): Long = version + 1
     override def abort(): Unit = { }
     override def id: StateStoreId = null
     override def version: Long = 0
-    override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException }
     override def numKeys(): Long = map.size
     override def hasCommitted: Boolean = true
   }
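
The MemoryStateStore changes above track the reworked StateStore interface: get() now returns the UnsafeRow directly (null when the key is absent) instead of an Option, iterator() yields UnsafeRowPair objects, and the filter/updates/remove(condition) methods are gone. A short sketch of the caller-side adjustment; the names store, key and handle are assumed for illustration and are not part of the patch:

    // Old style: store.get(key) returned Option[UnsafeRow].
    // New style: the result is nullable, so callers wrap it when an Option is convenient.
    val valueOpt: Option[UnsafeRow] = Option(store.get(key))

    // iterator() now yields UnsafeRowPair instead of (UnsafeRow, UnsafeRow).
    store.iterator().foreach { pair => handle(pair.key, pair.value) }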

http://git-wip-us.apache.org/repos/asf/spark/blob/fa757ee1/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 1fc0629..280f2dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import scala.util.control.ControlThrowable
 
 import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -31,6 +32,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
@@ -614,6 +616,30 @@ class StreamSuite extends StreamTest {
     assertDescContainsQueryNameAnd(batch = 2)
     query.stop()
   }
+
+  testQuietly("specify custom state store provider") {
+    val queryName = "memStream"
+    val providerClassName = classOf[TestStateStoreProvider].getCanonicalName
+    withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) {
+      val input = MemoryStream[Int]
+      val query = input
+        .toDS()
+        .groupBy()
+        .count()
+        .writeStream
+        .outputMode("complete")
+        .format("memory")
+        .queryName(queryName)
+        .start()
+      input.addData(1, 2, 3)
+      val e = intercept[Exception] {
+        query.awaitTermination()
+      }
+
+      assert(e.getMessage.contains(providerClassName))
+      assert(e.getMessage.contains("instantiated"))
+    }
+  }
 }
 
 abstract class FakeSource extends StreamSourceProvider {
@@ -719,3 +745,22 @@ object ThrowingInterruptedIOException {
    */
   @volatile var createSourceLatch: CountDownLatch = null
 }
+
+class TestStateStoreProvider extends StateStoreProvider {
+
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int],
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    throw new Exception("Successfully instantiated")
+  }
+
+  override def id: StateStoreId = null
+
+  override def close(): Unit = { }
+
+  override def getStore(version: Long): StateStore = null
+}
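
The new StreamSuite test above exercises the spark.sql.streaming.stateStore.providerClass setting introduced by this change, which selects the StateStoreProvider implementation used by stateful streaming queries. A hedged usage sketch; only the config key comes from the patch, and the com.example.MyStateStoreProvider class name is hypothetical:

    // Any stateful streaming query started after this will instantiate the named provider
    // class instead of the default HDFS-backed one.
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "com.example.MyStateStoreProvider")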

