This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3257a30  [SPARK-35784][SS] Implementation for RocksDB instance
3257a30 is described below

commit 3257a30e5399d4f366e4aae60b04371b31514fb4
Author: Yuanjian Li <yuanjian...@databricks.com>
AuthorDate: Tue Jun 29 17:46:45 2021 -0700

    [SPARK-35784][SS] Implementation for RocksDB instance
    
    ### What changes were proposed in this pull request?
    This adds the implementation of the RocksDB instance used in the RocksDB state store.
    It acts as a handler for the native RocksDB instance and the RocksDBFileManager.
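    A minimal usage sketch of the new class (the DFS checkpoint directory and the
    key/value bytes below are illustrative only, not part of this patch):
    
    ```scala
    import org.apache.spark.sql.execution.streaming.state.{RocksDB, RocksDBConf}
    
    // Hypothetical DFS directory where versioned checkpoints are written.
    val db = new RocksDB(dfsRootDir = "/tmp/rocksdb-checkpoints", conf = RocksDBConf())
    
    db.load(0)                           // start from the (empty) version 0
    db.put("a".getBytes, "1".getBytes)   // staged in the WriteBatchWithIndex, not yet durable
    val v1 = db.commit()                 // write + flush + checkpoint + sync to DFS, returns 1
    db.load(v1)                          // any committed version can be reloaded later
    db.close()
    ```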
    
    ### Why are the changes needed?
    Part of the RocksDB state store implementation.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    New UT added.
    
    Closes #32928 from xuanyuanking/SPARK-35784.
    
    Authored-by: Yuanjian Li <yuanjian...@databricks.com>
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 dev/deps/spark-deps-hadoop-2.7-hive-2.3            |   1 +
 dev/deps/spark-deps-hadoop-3.2-hive-2.3            |   1 +
 sql/core/pom.xml                                   |   5 +
 .../sql/execution/streaming/state/RocksDB.scala    | 452 +++++++++++++++++++++
 .../execution/streaming/state/RocksDBLoader.scala  |  60 +++
 .../execution/streaming/state/RocksDBSuite.scala   |  97 +++++
 6 files changed, 616 insertions(+)

diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3
index b18df07..0d8e032 100644
--- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3
@@ -211,6 +211,7 @@ parquet-jackson/1.12.0//parquet-jackson-1.12.0.jar
 protobuf-java/2.5.0//protobuf-java-2.5.0.jar
 py4j/0.10.9.2//py4j-0.10.9.2.jar
 pyrolite/4.30//pyrolite-4.30.jar
+rocksdbjni/6.2.2//rocksdbjni-6.2.2.jar
 scala-collection-compat_2.12/2.1.1//scala-collection-compat_2.12-2.1.1.jar
 scala-compiler/2.12.14//scala-compiler-2.12.14.jar
 scala-library/2.12.14//scala-library-2.12.14.jar
diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3
index bc77aa6..b7d4938 100644
--- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3
@@ -182,6 +182,7 @@ parquet-jackson/1.12.0//parquet-jackson-1.12.0.jar
 protobuf-java/2.5.0//protobuf-java-2.5.0.jar
 py4j/0.10.9.2//py4j-0.10.9.2.jar
 pyrolite/4.30//pyrolite-4.30.jar
+rocksdbjni/6.2.2//rocksdbjni-6.2.2.jar
 scala-collection-compat_2.12/2.1.1//scala-collection-compat_2.12-2.1.1.jar
 scala-compiler/2.12.14//scala-compiler-2.12.14.jar
 scala-library/2.12.14//scala-library-2.12.14.jar
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index f7bbe80..149b58b 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -36,6 +36,11 @@
 
   <dependencies>
     <dependency>
+      <groupId>org.rocksdb</groupId>
+      <artifactId>rocksdbjni</artifactId>
+      <version>6.2.2</version>
+    </dependency>
+    <dependency>
       <groupId>com.univocity</groupId>
       <artifactId>univocity-parsers</artifactId>
       <type>jar</type>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
new file mode 100644
index 0000000..82aa166
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+import java.util.Locale
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.{mutable, Map}
+import scala.ref.WeakReference
+import scala.util.Try
+
+import org.apache.hadoop.conf.Configuration
+import org.rocksdb.{RocksDB => NativeRocksDB, _}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.util.{NextIterator, Utils}
+
+/**
+ * Class representing a RocksDB instance that checkpoints versions of data to DFS.
+ * After a set of updates, a new version can be committed by calling `commit()`.
+ * Any past version can be loaded by calling `load(version)`.
+ *
+ * @note This class is not thread-safe, so use it only from one thread.
+ * @see [[RocksDBFileManager]] to see how the files are laid out in local disk and DFS.
+ * @param dfsRootDir   Remote directory where checkpoints are going to be written
+ * @param conf         Configuration for RocksDB
+ * @param localRootDir Root directory in local disk that is used for working and checkpointing dirs
+ * @param hadoopConf   Hadoop configuration for talking to the remote file system
+ * @param loggingId    Id that will be prepended in logs for isolating concurrent RocksDBs
+ */
+class RocksDB(
+    dfsRootDir: String,
+    val conf: RocksDBConf,
+    localRootDir: File = Utils.createTempDir(),
+    hadoopConf: Configuration = new Configuration,
+    loggingId: String = "") extends Logging {
+
+  RocksDBLoader.loadLibrary()
+
+  // Java wrapper objects linking to native RocksDB objects
+  private val readOptions = new ReadOptions()  // used for gets
+  private val writeOptions = new WriteOptions().setSync(true)  // wait for batched write to complete
+  private val flushOptions = new FlushOptions().setWaitForFlush(true)  // wait for flush to complete
+  private val writeBatch = new WriteBatchWithIndex(true)  // overwrite multiple updates to a key
+
+  private val bloomFilter = new BloomFilter()
+  private val tableFormatConfig = new BlockBasedTableConfig()
+  tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
+  tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
+  tableFormatConfig.setFilterPolicy(bloomFilter)
+
+  private val dbOptions = new Options() // options to open the RocksDB
+  dbOptions.setCreateIfMissing(true)
+  dbOptions.setTableFormatConfig(tableFormatConfig)
+  private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j
+  dbOptions.setStatistics(new Statistics())
+
+  private val workingDir = createTempDir("workingDir")
+  private val fileManager = new RocksDBFileManager(
+    dfsRootDir, createTempDir("fileManager"), hadoopConf, loggingId = loggingId)
+  private val byteArrayPair = new ByteArrayPair()
+  private val commitLatencyMs = new mutable.HashMap[String, Long]()
+  private val acquireLock = new Object
+
+  @volatile private var db: NativeRocksDB = _
+  @volatile private var loadedVersion = -1L   // -1 = nothing valid is loaded
+  @volatile private var numKeysOnLoadedVersion = 0L
+  @volatile private var numKeysOnWritingVersion = 0L
+
+  @GuardedBy("acquireLock")
+  @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
+
+  /**
+   * Load the given version of data in a native RocksDB instance.
+   * Note that this will copy all the necessary files from DFS to local disk as needed,
+   * and possibly restart the native RocksDB instance.
+   */
+  def load(version: Long): RocksDB = {
+    assert(version >= 0)
+    acquire()
+    logInfo(s"Loading $version")
+    try {
+      if (loadedVersion != version) {
+        closeDB()
+        val metadata = fileManager.loadCheckpointFromDfs(version, workingDir)
+        openDB()
+        numKeysOnWritingVersion = metadata.numKeys
+        numKeysOnLoadedVersion = metadata.numKeys
+        loadedVersion = version
+      }
+      writeBatch.clear()
+      logInfo(s"Loaded $version")
+    } catch {
+      case t: Throwable =>
+        loadedVersion = -1  // invalidate loaded data
+        throw t
+    }
+    this
+  }
+
+  /**
+   * Get the value for the given key if present, or null.
+   * @note This will return the last written value even if it was uncommitted.
+   */
+  def get(key: Array[Byte]): Array[Byte] = {
+    writeBatch.getFromBatchAndDB(db, readOptions, key)
+  }
+
+  /**
+   * Put the given value for the given key and return the last written value.
+   * @note This update is not committed to disk until commit() is called.
+   */
+  def put(key: Array[Byte], value: Array[Byte]): Array[Byte] = {
+    val oldValue = writeBatch.getFromBatchAndDB(db, readOptions, key)
+    writeBatch.put(key, value)
+    if (oldValue == null) {
+      numKeysOnWritingVersion += 1
+    }
+    oldValue
+  }
+
+  /**
+   * Remove the key if present, and return the previous value if it was present (null otherwise).
+   * @note This update is not committed to disk until commit() is called.
+   */
+  def remove(key: Array[Byte]): Array[Byte] = {
+    val value = writeBatch.getFromBatchAndDB(db, readOptions, key)
+    if (value != null) {
+      writeBatch.remove(key)
+      numKeysOnWritingVersion -= 1
+    }
+    value
+  }
+
+  /**
+   * Get an iterator of all committed and uncommitted key-value pairs.
+   */
+  def iterator(): Iterator[ByteArrayPair] = {
+    val iter = writeBatch.newIteratorWithBase(db.newIterator())
+    logInfo(s"Getting iterator from version $loadedVersion")
+    iter.seekToFirst()
+
+    // Attempt to close this iterator if there is a task failure, or a task interruption.
+    // This is a hack because it assumes that the RocksDB is running inside a task.
+    Option(TaskContext.get()).foreach { tc =>
+      tc.addTaskCompletionListener[Unit] { _ => iter.close() }
+    }
+
+    new NextIterator[ByteArrayPair] {
+      override protected def getNext(): ByteArrayPair = {
+        if (iter.isValid) {
+          byteArrayPair.set(iter.key, iter.value)
+          iter.next()
+          byteArrayPair
+        } else {
+          finished = true
+          iter.close()
+          null
+        }
+      }
+      override protected def close(): Unit = { iter.close() }
+    }
+  }
+
+  /**
+   * Commit all the updates made as a version to DFS. The steps it needs to do to commit are:
+   * - Write all the updates to the native RocksDB
+   * - Flush all changes to disk
+   * - Create a RocksDB checkpoint in a new local dir
+   * - Sync the checkpoint dir files to DFS
+   */
+  def commit(): Long = {
+    val newVersion = loadedVersion + 1
+    val checkpointDir = createTempDir("checkpoint")
+    try {
+      // Make sure the directory does not exist. Native RocksDB fails if the directory to
+      // checkpoint exists.
+      Utils.deleteRecursively(checkpointDir)
+
+      logInfo(s"Writing updates for $newVersion")
+      val writeTimeMs = timeTakenMs { db.write(writeOptions, writeBatch) }
+
+      logInfo(s"Flushing updates for $newVersion")
+      val flushTimeMs = timeTakenMs { db.flush(flushOptions) }
+
+      val compactTimeMs = if (conf.compactOnCommit) {
+        logInfo("Compacting")
+        timeTakenMs { db.compactRange() }
+      } else 0
+      logInfo("Pausing background work")
+
+      val pauseTimeMs = timeTakenMs {
+        db.pauseBackgroundWork() // To avoid files being changed while committing
+      }
+
+      logInfo(s"Creating checkpoint for $newVersion in $checkpointDir")
+      val checkpointTimeMs = timeTakenMs {
+        val cp = Checkpoint.create(db)
+        cp.createCheckpoint(checkpointDir.toString)
+      }
+
+      logInfo(s"Syncing checkpoint for $newVersion to DFS")
+      val fileSyncTimeMs = timeTakenMs {
+        fileManager.saveCheckpointToDfs(checkpointDir, newVersion, numKeysOnWritingVersion)
+      }
+      numKeysOnLoadedVersion = numKeysOnWritingVersion
+      loadedVersion = newVersion
+      commitLatencyMs ++= Map(
+        "writeBatch" -> writeTimeMs,
+        "flush" -> flushTimeMs,
+        "compact" -> compactTimeMs,
+        "pause" -> pauseTimeMs,
+        "checkpoint" -> checkpointTimeMs,
+        "fileSync" -> fileSyncTimeMs
+      )
+      loadedVersion
+    } catch {
+      case t: Throwable =>
+        loadedVersion = -1  // invalidate loaded version
+        throw t
+    } finally {
+      db.continueBackgroundWork()
+      silentDeleteRecursively(checkpointDir, s"committing $newVersion")
+      release()
+    }
+  }
+
+  /**
+   * Drop uncommitted changes, and roll back to previous version.
+   */
+  def rollback(): Unit = {
+    writeBatch.clear()
+    numKeysOnWritingVersion = numKeysOnLoadedVersion
+    release()
+    logInfo(s"Rolled back to $loadedVersion")
+  }
+
+  /** Release all resources */
+  def close(): Unit = {
+    try {
+      closeDB()
+
+      // Release all resources related to native RockDB objects
+      writeBatch.clear()
+      writeBatch.close()
+      readOptions.close()
+      writeOptions.close()
+      flushOptions.close()
+      dbOptions.close()
+      dbLogger.close()
+      silentDeleteRecursively(localRootDir, "closing RocksDB")
+    } catch {
+      case e: Exception =>
+        logWarning("Error closing RocksDB", e)
+    }
+  }
+
+  /** Get the latest version available in the DFS */
+  def getLatestVersion(): Long = fileManager.getLatestVersion()
+
+  private def acquire(): Unit = acquireLock.synchronized {
+    val newAcquiredThreadInfo = AcquiredThreadInfo()
+    val waitStartTime = System.currentTimeMillis
+    def timeWaitedMs = System.currentTimeMillis - waitStartTime
+    def isAcquiredByDifferentThread = acquiredThreadInfo != null &&
+      acquiredThreadInfo.threadRef.get.isDefined &&
+      newAcquiredThreadInfo.threadRef.get.get.getId != acquiredThreadInfo.threadRef.get.get.getId
+
+    while (isAcquiredByDifferentThread && timeWaitedMs < conf.lockAcquireTimeoutMs) {
+      acquireLock.wait(10)
+    }
+    if (isAcquiredByDifferentThread) {
+      val stackTraceOutput = acquiredThreadInfo.threadRef.get.get.getStackTrace.mkString("\n")
+      val msg = s"RocksDB instance could not be acquired by $newAcquiredThreadInfo as it " +
+        s"was not released by $acquiredThreadInfo after $timeWaitedMs ms.\n" +
+        s"Thread holding the lock has trace: $stackTraceOutput"
+      logError(msg)
+      throw new IllegalStateException(s"$loggingId: $msg")
+    } else {
+      acquiredThreadInfo = newAcquiredThreadInfo
+      // Add a listener to always release the lock when the task (if active) completes
+      Option(TaskContext.get).foreach(_.addTaskCompletionListener[Unit] { _ => this.release() })
+      logInfo(s"RocksDB instance was acquired by $acquiredThreadInfo")
+    }
+  }
+
+  private def release(): Unit = acquireLock.synchronized {
+    acquiredThreadInfo = null
+    acquireLock.notifyAll()
+  }
+
+  private def openDB(): Unit = {
+    assert(db == null)
+    db = NativeRocksDB.open(dbOptions, workingDir.toString)
+    logInfo(s"Opened DB with conf ${conf}")
+  }
+
+  private def closeDB(): Unit = {
+    if (db != null) {
+      db.close()
+      db = null
+    }
+  }
+
+  /** Create a native RocksDB logger that forwards native logs to log4j with correct log levels. */
+  private def createLogger(): Logger = {
+    val dbLogger = new Logger(dbOptions) {
+      override def log(infoLogLevel: InfoLogLevel, logMsg: String) = {
+        // Map DB log level to log4j levels
+        // Warn is mapped to info because RocksDB warn is too verbose
+        // (e.g. dumps non-warning stuff like stats)
+        val loggingFunc: ( => String) => Unit = infoLogLevel match {
+          case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
+          case InfoLogLevel.WARN_LEVEL | InfoLogLevel.INFO_LEVEL => logInfo(_)
+          case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
+          case _ => logTrace(_)
+        }
+        loggingFunc(s"[NativeRocksDB-${infoLogLevel.getValue}] $logMsg")
+      }
+    }
+
+    var dbLogLevel = InfoLogLevel.ERROR_LEVEL
+    if (log.isWarnEnabled) dbLogLevel = InfoLogLevel.WARN_LEVEL
+    if (log.isInfoEnabled) dbLogLevel = InfoLogLevel.INFO_LEVEL
+    if (log.isDebugEnabled) dbLogLevel = InfoLogLevel.DEBUG_LEVEL
+    dbOptions.setLogger(dbLogger)
+    dbOptions.setInfoLogLevel(dbLogLevel)
+    logInfo(s"Set RocksDB native logging level to $dbLogLevel")
+    dbLogger
+  }
+
+  /** Create a temp directory inside the local root directory */
+  private def createTempDir(prefix: String): File = {
+    Utils.createDirectory(localRootDir.getAbsolutePath, prefix)
+  }
+
+  /** Attempt to delete recursively, and log the error if any */
+  private def silentDeleteRecursively(file: File, msg: String): Unit = {
+    try {
+      Utils.deleteRecursively(file)
+    } catch {
+      case e: Exception =>
+        logWarning(s"Error recursively deleting local dir $file while $msg", e)
+    }
+  }
+
+  /** Records the duration of running `body` for the next query progress update. */
+  protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2
+
+  override protected def logName: String = s"${super.logName} $loggingId"
+}
+
+
+/** Mutable and reusable pair of byte arrays */
+class ByteArrayPair(var key: Array[Byte] = null, var value: Array[Byte] = null) {
+  def set(key: Array[Byte], value: Array[Byte]): ByteArrayPair = {
+    this.key = key
+    this.value = value
+    this
+  }
+}
+
+
+/**
+ * Configurations for optimizing RocksDB
+ *
+ * @param compactOnCommit Whether to compact RocksDB data before commit / checkpointing
+ */
+case class RocksDBConf(
+    minVersionsToRetain: Int,
+    compactOnCommit: Boolean,
+    pauseBackgroundWorkForCommit: Boolean,
+    blockSizeKB: Long,
+    blockCacheSizeMB: Long,
+    lockAcquireTimeoutMs: Long)
+
+object RocksDBConf {
+  /** Common prefix of all confs in SQLConf that affects RocksDB */
+  val ROCKSDB_CONF_NAME_PREFIX = "spark.sql.streaming.stateStore.rocksdb"
+
+  private case class ConfEntry(name: String, default: String) {
+    def fullName: String = s"$ROCKSDB_CONF_NAME_PREFIX.${name}".toLowerCase(Locale.ROOT)
+  }
+
+  // Configuration that specifies whether to compact the RocksDB data every time data is committed
+  private val COMPACT_ON_COMMIT_CONF = ConfEntry("compactOnCommit", "false")
+  private val PAUSE_BG_WORK_FOR_COMMIT_CONF = ConfEntry("pauseBackgroundWorkForCommit", "true")
+  private val BLOCK_SIZE_KB_CONF = ConfEntry("blockSizeKB", "4")
+  private val BLOCK_CACHE_SIZE_MB_CONF = ConfEntry("blockCacheSizeMB", "8")
+  private val LOCK_ACQUIRE_TIMEOUT_MS_CONF = ConfEntry("lockAcquireTimeoutMs", "60000")
+
+  def apply(storeConf: StateStoreConf): RocksDBConf = {
+    val confs = CaseInsensitiveMap[String](storeConf.confs)
+
+    def getBooleanConf(conf: ConfEntry): Boolean = {
+      Try { confs.getOrElse(conf.fullName, conf.default).toBoolean } getOrElse {
+        throw new IllegalArgumentException(s"Invalid value for '${conf.fullName}', must be boolean")
+      }
+    }
+
+    def getPositiveLongConf(conf: ConfEntry): Long = {
+      Try { confs.getOrElse(conf.fullName, conf.default).toLong } filter { _ >= 0 } getOrElse {
+        throw new IllegalArgumentException(
+          s"Invalid value for '${conf.fullName}', must be a positive integer")
+      }
+    }
+
+    RocksDBConf(
+      storeConf.minVersionsToRetain,
+      getBooleanConf(COMPACT_ON_COMMIT_CONF),
+      getBooleanConf(PAUSE_BG_WORK_FOR_COMMIT_CONF),
+      getPositiveLongConf(BLOCK_SIZE_KB_CONF),
+      getPositiveLongConf(BLOCK_CACHE_SIZE_MB_CONF),
+      getPositiveLongConf(LOCK_ACQUIRE_TIMEOUT_MS_CONF))
+  }
+
+  def apply(): RocksDBConf = apply(new StateStoreConf())
+}
+
+case class AcquiredThreadInfo() {
+  val threadRef: WeakReference[Thread] = new WeakReference[Thread](Thread.currentThread())
+  val tc: TaskContext = TaskContext.get()
+
+  override def toString(): String = {
+    val taskStr = if (tc != null) {
+      val taskDetails =
+        s"${tc.partitionId}.${tc.attemptNumber} in stage ${tc.stageId}, TID 
${tc.taskAttemptId}"
+      s", task: $taskDetails"
+    } else ""
+
+    s"[ThreadId: ${threadRef.get.map(_.getId)}$taskStr]"
+  }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLoader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLoader.scala
new file mode 100644
index 0000000..cc51819
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLoader.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.rocksdb.{RocksDB => NativeRocksDB}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.UninterruptibleThread
+
+/**
+ * A wrapper for RocksDB library loading using an uninterruptible thread, as the native RocksDB
+ * code will throw an error when interrupted.
+ */
+object RocksDBLoader extends Logging {
+  /**
+   * Keeps track of the exception thrown from the loading thread, if any.
+   */
+  private var exception: Option[Throwable] = null
+
+  private val loadLibraryThread = new UninterruptibleThread("RocksDBLoader") {
+    override def run(): Unit = {
+      try {
+        runUninterruptibly {
+          NativeRocksDB.loadLibrary()
+          exception = None
+        }
+      } catch {
+        case e: Throwable =>
+          exception = Some(e)
+      }
+    }
+  }
+
+  def loadLibrary(): Unit = synchronized {
+    if (exception == null) {
+      loadLibraryThread.start()
+      logInfo("RocksDB library loading thread started")
+      loadLibraryThread.join()
+      exception.foreach(throw _)
+      logInfo("RocksDB library loading thread finished successfully")
+    } else {
+      exception.foreach(throw _)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index c75eed2..a11eb8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -33,6 +33,75 @@ import org.apache.spark.util.Utils
 
 class RocksDBSuite extends SparkFunSuite {
 
+  test("RocksDB: get, put, iterator, commit, load") {
+    def testOps(compactOnCommit: Boolean): Unit = {
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()  // to make sure that the directory gets created
+
+      val conf = RocksDBConf().copy(compactOnCommit = compactOnCommit)
+      withDB(remoteDir, conf = conf) { db =>
+        assert(db.get("a") === null)
+        assert(iterator(db).isEmpty)
+
+        db.put("a", "1")
+        assert(toStr(db.get("a")) === "1")
+        db.commit()
+      }
+
+      withDB(remoteDir, conf = conf, version = 0) { db =>
+        // version 0 can be loaded again
+        assert(toStr(db.get("a")) === null)
+        assert(iterator(db).isEmpty)
+      }
+
+      withDB(remoteDir, conf = conf, version = 1) { db =>
+        // version 1 data recovered correctly
+        assert(toStr(db.get("a")) === "1")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1")))
+
+        // make changes but do not commit version 2
+        db.put("b", "2")
+        assert(toStr(db.get("b")) === "2")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1"), ("b", "2")))
+      }
+
+      withDB(remoteDir, conf = conf, version = 1) { db =>
+        // version 1 data not changed
+        assert(toStr(db.get("a")) === "1")
+        assert(db.get("b") === null)
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1")))
+
+        // commit version 2
+        db.put("b", "2")
+        assert(toStr(db.get("b")) === "2")
+        db.commit()
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1"), ("b", "2")))
+      }
+
+      withDB(remoteDir, conf = conf, version = 1) { db =>
+        // version 1 data not changed
+        assert(toStr(db.get("a")) === "1")
+        assert(db.get("b") === null)
+      }
+
+      withDB(remoteDir, conf = conf, version = 2) { db =>
+        // version 2 can be loaded again
+        assert(toStr(db.get("b")) === "2")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1"), ("b", "2")))
+
+        db.load(1)
+        assert(toStr(db.get("b")) === null)
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1")))
+      }
+    }
+
+    for (compactOnCommit <- Seq(false, true)) {
+      withClue(s"compactOnCommit = $compactOnCommit") {
+        testOps(compactOnCommit)
+      }
+    }
+  }
+
   test("RocksDBFileManager: upload only new immutable files") {
     withTempDir { dir =>
       val dfsRootDir = dir.getAbsolutePath
@@ -167,6 +236,26 @@ class RocksDBSuite extends SparkFunSuite {
     // scalastyle:on line.size.limit
   }
 
+  def withDB[T](
+      remoteDir: String,
+      version: Int = 0,
+      conf: RocksDBConf = RocksDBConf().copy(compactOnCommit = false, minVersionsToRetain = 100),
+      hadoopConf: Configuration = new Configuration())(
+      func: RocksDB => T): T = {
+    var db: RocksDB = null
+    try {
+      db = new RocksDB(
+        remoteDir, conf = conf, hadoopConf = hadoopConf,
+        loggingId = s"[Thread-${Thread.currentThread.getId}]")
+      db.load(version)
+      func(db)
+    } finally {
+      if (db != null) {
+        db.close()
+      }
+    }
+  }
+
   def generateFiles(dir: String, fileToLengths: Seq[(String, Int)]): Unit = {
     fileToLengths.foreach { case (fileName, length) =>
       val file = new File(dir, fileName)
@@ -200,6 +289,14 @@ class RocksDBSuite extends SparkFunSuite {
 
   implicit def toFile(path: String): File = new File(path)
 
+  implicit def toArray(str: String): Array[Byte] = if (str != null) str.getBytes else null
+
+  implicit def toStr(bytes: Array[Byte]): String = if (bytes != null) new String(bytes) else null
+
+  def toStr(kv: ByteArrayPair): (String, String) = (toStr(kv.key), toStr(kv.value))
+
+  def iterator(db: RocksDB): Iterator[(String, String)] = db.iterator().map(toStr)
+
   def listFiles(file: File): Seq[File] = {
     if (!file.exists()) return Seq.empty
    file.listFiles.filter(file => !file.getName.endsWith("crc") && !file.isDirectory)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
