Repository: spark
Updated Branches:
  refs/heads/master 1ebe7ffe0 -> 2ebd0838d


[SPARK-21192][SS] Preserve State Store provider class configuration across 
StreamingQuery restarts

## What changes were proposed in this pull request?

If the SQL conf for the StateStore provider class is changed between restarts 
(e.g. the query was started with providerClass1 and then restarted with 
providerClass2), then the query will fail in an unpredictable way, because files 
saved by one provider class cannot be read by the other.

Ideally, the provider class used to start the query should also be used to 
restart it, and the provider class configured in the session where the query is 
being restarted should be ignored.

This PR saves the provider class config to the OffsetSeqLog, in the same way 
the number of shuffle partitions is saved and recovered.
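
As a minimal sketch of the recovery rule (a hypothetical helper, not the code 
in this diff): a value recorded in the offset log when the query first started 
wins over the value in the session performing the restart; the session value is 
only a backward-compatibility fallback for checkpoints that never recorded the 
key.

```scala
// Hypothetical illustration of the precedence rule introduced by this PR.
def confToUseOnRestart(
    confKey: String,                     // e.g. spark.sql.streaming.stateStore.providerClass
    offsetLogConf: Map[String, String],  // confs read back from the OffsetSeqLog
    sessionConf: Map[String, String]): Option[String] =
  offsetLogConf.get(confKey).orElse(sessionConf.get(confKey))
```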

## How was this patch tested?
New unit tests.

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #18402 from tdas/SPARK-21192.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2ebd0838
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2ebd0838
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2ebd0838

Branch: refs/heads/master
Commit: 2ebd0838d165fe33b404e8d86c0fa445d1f47439
Parents: 1ebe7ff
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Fri Jun 23 10:55:02 2017 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Fri Jun 23 10:55:02 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/internal/SQLConf.scala |  5 +-
 .../sql/execution/streaming/OffsetSeq.scala     | 39 ++++++++++++++-
 .../execution/streaming/StreamExecution.scala   | 26 +++-------
 .../execution/streaming/state/StateStore.scala  |  3 +-
 .../streaming/state/StateStoreConf.scala        |  2 +-
 .../execution/streaming/OffsetSeqLogSuite.scala | 10 ++--
 .../spark/sql/streaming/StreamSuite.scala       | 51 ++++++++++++++++----
 7 files changed, 96 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2ebd0838/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e609256..9c8e26a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -601,7 +601,8 @@ object SQLConf {
         "The class used to manage state data in stateful streaming queries. 
This class must " +
           "be a subclass of StateStoreProvider, and must have a zero-arg 
constructor.")
       .stringConf
-      .createOptional
+      .createWithDefault(
+        
"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")
 
   val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
     buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
@@ -897,7 +898,7 @@ class SQLConf extends Serializable with Logging {
 
   def optimizerInSetConversionThreshold: Int = 
getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)
 
-  def stateStoreProviderClass: Option[String] = 
getConf(STATE_STORE_PROVIDER_CLASS)
+  def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)
 
   def stateStoreMinDeltasForSnapshot: Int = 
getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
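
Illustrative effect of the new default (not part of the diff; assumes a 
SparkSession named `spark`):

    // Now resolves to the HDFS-backed provider class name even when the user
    // has never set spark.sql.streaming.stateStore.providerClass explicitly.
    val providerClass = spark.conf.get("spark.sql.streaming.stateStore.providerClass")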
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2ebd0838/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
index 8249ada..4e0a468 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -20,6 +20,10 @@ package org.apache.spark.sql.execution.streaming
 import org.json4s.NoTypeHints
 import org.json4s.jackson.Serialization
 
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.RuntimeConfig
+import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, 
STATE_STORE_PROVIDER_CLASS}
+
 /**
  * An ordered collection of offsets, used to track the progress of processing 
data from one or more
  * [[Source]]s that are present in a streaming query. This is similar to 
simplified, single-instance
@@ -78,7 +82,40 @@ case class OffsetSeqMetadata(
   def json: String = Serialization.write(this)(OffsetSeqMetadata.format)
 }
 
-object OffsetSeqMetadata {
+object OffsetSeqMetadata extends Logging {
   private implicit val format = Serialization.formats(NoTypeHints)
+  private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, 
STATE_STORE_PROVIDER_CLASS)
+
   def apply(json: String): OffsetSeqMetadata = 
Serialization.read[OffsetSeqMetadata](json)
+
+  def apply(
+      batchWatermarkMs: Long,
+      batchTimestampMs: Long,
+      sessionConf: RuntimeConfig): OffsetSeqMetadata = {
+    val confs = relevantSQLConfs.map { conf => conf.key -> 
sessionConf.get(conf.key) }.toMap
+    OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs)
+  }
+
+  /** Set the SparkSession configuration with the values in the metadata */
+  def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: RuntimeConfig): 
Unit = {
+    OffsetSeqMetadata.relevantSQLConfs.map(_.key).foreach { confKey =>
+
+      metadata.conf.get(confKey) match {
+
+        case Some(valueInMetadata) =>
+          // Config value exists in the metadata, update the session config 
with this value
+          val optionalValueInSession = sessionConf.getOption(confKey)
+          if (optionalValueInSession.isDefined && optionalValueInSession.get 
!= valueInMetadata) {
+            logWarning(s"Updating the value of conf '$confKey' in current 
session from " +
+              s"'${optionalValueInSession.get}' to '$valueInMetadata'.")
+          }
+          sessionConf.set(confKey, valueInMetadata)
+
+        case None =>
+          // For backward compatibility, if a config was not recorded in the 
offset log,
+          // then log it, and let the existing conf value in SparkSession 
prevail.
+          logWarning (s"Conf '$confKey' was not found in the offset log, using 
existing value")
+      }
+    }
+  }
 }
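
Illustrative usage of the two helpers added above (not part of the diff; 
assumes a SparkSession named `spark`):

    // At query start: capture the relevant confs (number of shuffle partitions
    // and the state store provider class) from the session into the metadata
    // that is written to the offset log.
    val metadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, spark.conf)

    // On restart: push the recorded values back into the (cloned) session, so
    // the provider class used when the query was first started keeps being used.
    OffsetSeqMetadata.setSessionConf(metadata, spark.conf)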

http://git-wip-us.apache.org/repos/asf/spark/blob/2ebd0838/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 06bdec8..d5f8d2a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -125,9 +125,8 @@ class StreamExecution(
   }
 
   /** Metadata associated with the offset seq of a batch in the query. */
-  protected var offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, 
batchTimestampMs = 0,
-    conf = Map(SQLConf.SHUFFLE_PARTITIONS.key ->
-      sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString))
+  protected var offsetSeqMetadata = OffsetSeqMetadata(
+    batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf)
 
   override val id: UUID = UUID.fromString(streamMetadata.id)
 
@@ -285,9 +284,8 @@ class StreamExecution(
       val sparkSessionToRunBatches = sparkSession.cloneSession()
       // Adaptive execution can change num shuffle partitions, disallow
       
sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, 
"false")
-      offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, 
batchTimestampMs = 0,
-        conf = Map(SQLConf.SHUFFLE_PARTITIONS.key ->
-          sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)))
+      offsetSeqMetadata = OffsetSeqMetadata(
+        batchWatermarkMs = 0, batchTimestampMs = 0, 
sparkSessionToRunBatches.conf)
 
       if (state.compareAndSet(INITIALIZING, ACTIVE)) {
         // Unblock `awaitInitialization`
@@ -441,21 +439,9 @@ class StreamExecution(
 
         // update offset metadata
         nextOffsets.metadata.foreach { metadata =>
-          val shufflePartitionsSparkSession: Int =
-            sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS)
-          val shufflePartitionsToUse = 
metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, {
-            // For backward compatibility, if # partitions was not recorded in 
the offset log,
-            // then ensure it is not missing. The new value is picked up from 
the conf.
-            logWarning("Number of shuffle partitions from previous run not 
found in checkpoint. "
-              + s"Using the value from the conf, 
$shufflePartitionsSparkSession partitions.")
-            shufflePartitionsSparkSession
-          })
+          OffsetSeqMetadata.setSessionConf(metadata, 
sparkSessionToRunBatches.conf)
           offsetSeqMetadata = OffsetSeqMetadata(
-            metadata.batchWatermarkMs, metadata.batchTimestampMs,
-            metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> 
shufflePartitionsToUse.toString))
-          // Update conf with correct number of shuffle partitions
-          sparkSessionToRunBatches.conf.set(
-            SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString)
+            metadata.batchWatermarkMs, metadata.batchTimestampMs, 
sparkSessionToRunBatches.conf)
         }
 
         /* identify the current batch id: if commit log indicates we 
successfully processed the

http://git-wip-us.apache.org/repos/asf/spark/blob/2ebd0838/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index a94ff8a..8688646 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -172,8 +172,7 @@ object StateStoreProvider {
       indexOrdinal: Option[Int], // for sorting the data
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStoreProvider = {
-    val providerClass = storeConf.providerClass.map(Utils.classForName)
-        .getOrElse(classOf[HDFSBackedStateStoreProvider])
+    val providerClass = Utils.classForName(storeConf.providerClass)
     val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider]
     provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, 
storeConf, hadoopConf)
     provider

http://git-wip-us.apache.org/repos/asf/spark/blob/2ebd0838/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index bab297c..765ff07 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -38,7 +38,7 @@ class StateStoreConf(@transient private val sqlConf: SQLConf)
    * Optional fully qualified name of the subclass of [[StateStoreProvider]]
    * managing state data. That is, the implementation of the State Store to 
use.
    */
-  val providerClass: Option[String] = sqlConf.stateStoreProviderClass
+  val providerClass: String = sqlConf.stateStoreProviderClass
 
   /**
    * Additional configurations related to state store. This will capture all 
configs in

http://git-wip-us.apache.org/repos/asf/spark/blob/2ebd0838/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
index dc55632..e6cdc06 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
@@ -37,16 +37,18 @@ class OffsetSeqLogSuite extends SparkFunSuite with 
SharedSQLContext {
     }
 
     // None set
-    assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}"""))
+    assert(new OffsetSeqMetadata(0, 0, Map.empty) === 
OffsetSeqMetadata("""{}"""))
 
     // One set
-    assert(OffsetSeqMetadata(1, 0, Map.empty) === 
OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
-    assert(OffsetSeqMetadata(0, 2, Map.empty) === 
OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
+    assert(new OffsetSeqMetadata(1, 0, Map.empty) ===
+      OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
+    assert(new OffsetSeqMetadata(0, 2, Map.empty) ===
+      OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
     assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) ===
       OffsetSeqMetadata(s"""{"conf": {"$key":2}}"""))
 
     // Two set
-    assert(OffsetSeqMetadata(1, 2, Map.empty) ===
+    assert(new OffsetSeqMetadata(1, 2, Map.empty) ===
       OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
     assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) ===
       OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}"""))

http://git-wip-us.apache.org/repos/asf/spark/blob/2ebd0838/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 86c3a35..6f7b9d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -637,19 +637,11 @@ class StreamSuite extends StreamTest {
   }
 
   testQuietly("specify custom state store provider") {
-    val queryName = "memStream"
     val providerClassName = classOf[TestStateStoreProvider].getCanonicalName
     withSQLConf("spark.sql.streaming.stateStore.providerClass" -> 
providerClassName) {
       val input = MemoryStream[Int]
-      val query = input
-        .toDS()
-        .groupBy()
-        .count()
-        .writeStream
-        .outputMode("complete")
-        .format("memory")
-        .queryName(queryName)
-        .start()
+      val df = input.toDS().groupBy().count()
+      val query = 
df.writeStream.outputMode("complete").format("memory").queryName("name").start()
       input.addData(1, 2, 3)
       val e = intercept[Exception] {
         query.awaitTermination()
@@ -659,6 +651,45 @@ class StreamSuite extends StreamTest {
       assert(e.getMessage.contains("instantiated"))
     }
   }
+
+  testQuietly("custom state store provider read from offset log") {
+    val input = MemoryStream[Int]
+    val df = input.toDS().groupBy().count()
+    val providerConf1 = "spark.sql.streaming.stateStore.providerClass" ->
+      
"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"
+    val providerConf2 = "spark.sql.streaming.stateStore.providerClass" ->
+      classOf[TestStateStoreProvider].getCanonicalName
+
+    def runQuery(queryName: String, checkpointLoc: String): Unit = {
+      val query = df.writeStream
+        .outputMode("complete")
+        .format("memory")
+        .queryName(queryName)
+        .option("checkpointLocation", checkpointLoc)
+        .start()
+      input.addData(1, 2, 3)
+      query.processAllAvailable()
+      query.stop()
+    }
+
+    withTempDir { dir =>
+      val checkpointLoc1 = new File(dir, "1").getCanonicalPath
+      withSQLConf(providerConf1) {
+        runQuery("query1", checkpointLoc1)  // generate checkpoints
+      }
+
+      val checkpointLoc2 = new File(dir, "2").getCanonicalPath
+      withSQLConf(providerConf2) {
+        // Verify that the new query uses the new provider, which throws an error on loading
+        intercept[Exception] {
+          runQuery("query2", checkpointLoc2)
+        }
+
+        // Verify old query from checkpoint will still use old provider
+        runQuery("query1", checkpointLoc1)
+      }
+    }
+  }
 }
 
 abstract class FakeSource extends StreamSourceProvider {

