This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4d72be3abdc4 [SPARK-47363][SS] Initial State without state reader implementation for State API v2
4d72be3abdc4 is described below

commit 4d72be3abdc4c651da029bdbd24a574099d45e7c
Author: jingz-db <jing.z...@databricks.com>
AuthorDate: Thu Mar 28 14:50:46 2024 +0900

    [SPARK-47363][SS] Initial State without state reader implementation for State API v2
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for users to provide a DataFrame that can be used to instantiate state for the query in the first batch, for the arbitrary state API v2.
    
    Note that populating the initial state will only happen for the first batch of the new streaming query. Trying to re-initialize state for the same grouping key will result in an error.
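    
    As a quick illustration of the processor-side hook, here is a minimal hedged sketch (class name, state variable, and types are illustrative and modeled on the tests added in this PR, not part of the patch itself) of a `StatefulProcessorWithInitialState` that seeds a value state from the user-provided initial state:
    
    ```
    // Illustrative sketch only: a processor that keeps a per-key count and
    // seeds it from the user-provided initial state in the first batch.
    class CountingProcessor
      extends StatefulProcessorWithInitialState[String, String, (String, Long), Long] {
      @transient private var _countState: ValueState[Long] = _
    
      override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
        _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
      }
    
      // Invoked at most once per grouping key, and only in the first batch.
      override def handleInitialState(key: String, initialState: Long): Unit = {
        _countState.update(initialState)
      }
    
      override def handleInputRows(
          key: String,
          inputRows: Iterator[String],
          timerValues: TimerValues,
          expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = {
        val updated = _countState.getOption().getOrElse(0L) + inputRows.size
        _countState.update(updated)
        Iterator.single((key, updated))
      }
    
      override def close(): Unit = {}
    }
    ```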
    
    ### Why are the changes needed?
    
    These changes are needed to support initial state. They are part of the broader effort to add a new stateful streaming operator for arbitrary state management, which provides the new features listed in the SPIP JIRA here - https://issues.apache.org/jira/browse/SPARK-45939
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    This PR introduces a new function:
    ```
    def transformWithState(
          statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
          timeoutMode: TimeoutMode,
          outputMode: OutputMode,
          initialState: KeyValueGroupedDataset[K, S]): Dataset[U]
    ```
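    
    A hedged usage sketch of the new overload (dataset and processor names are illustrative, following the shape of the tests added in this PR; note the overload is still `private[sql]` in this change): the initial state is supplied as a `KeyValueGroupedDataset` keyed the same way as the input stream, and is consumed only in the first batch.
    
    ```
    // Illustrative sketch only.
    // Initial state keyed like the input stream; only read in batch 0.
    val initialState: KeyValueGroupedDataset[String, Long] =
      Seq(("userA", 10L), ("userB", 5L)).toDS()
        .groupByKey(_._1)
        .mapValues(_._2)
    
    val result = inputStream.toDS()   // e.g. a MemoryStream[String]
      .groupByKey(x => x)
      .transformWithState(
        new CountingProcessor(),      // processor sketched above
        TimeoutMode.NoTimeouts(),
        OutputMode.Append(),
        initialState)
    ```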
    
    ### How was this patch tested?
    
    Unit tests in `TransformWithStateInitialStateSuite`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45467 from jingz-db/initial-state-state-v2.
    
    Lead-authored-by: jingz-db <jing.z...@databricks.com>
    Co-authored-by: Jing Zhan <135738831+jingz...@users.noreply.github.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../src/main/resources/error/error-classes.json    |   6 +
 docs/sql-error-conditions.md                       |   6 +
 .../spark/sql/streaming/StatefulProcessor.scala    |  19 ++
 .../spark/sql/catalyst/plans/logical/object.scala  |  55 +++-
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  38 ++-
 .../spark/sql/execution/SparkStrategies.scala      |  20 +-
 .../execution/streaming/IncrementalExecution.scala |   4 +-
 .../streaming/TransformWithStateExec.scala         | 254 ++++++++++++++----
 .../streaming/state/StateStoreErrors.scala         |  10 +
 .../sql/streaming/TransformWithMapStateSuite.scala |   5 +-
 .../TransformWithStateInitialStateSuite.scala      | 293 +++++++++++++++++++++
 .../sql/streaming/TransformWithStateSuite.scala    |  20 ++
 12 files changed, 661 insertions(+), 69 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 185e86853dfd..11c8204d2c93 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3553,6 +3553,12 @@
     ],
     "sqlState" : "42802"
   },
+  "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY" : {
+    "message" : [
+      "Cannot re-initialize state on the same grouping key during initial 
state handling for stateful processor. Invalid grouping key=<groupingKey>."
+    ],
+    "sqlState" : "42802"
+  },
   "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : {
     "message" : [
       "Failed to create column family with unsupported starting character and 
name=<colFamilyName>."
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 838ca2fa33c9..85b9e85ac420 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -2162,6 +2162,12 @@ Failed to perform stateful processor operation=`<operationType>` with invalid ha
 
 Failed to perform stateful processor operation=`<operationType>` with invalid timeoutMode=`<timeoutMode>`
 
+### STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY
+
+[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=`<groupingKey>`.
+
 ### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS
 
 [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
index ad9b807ddf5a..1a61972f0ed0 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
@@ -91,3 +91,22 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
     statefulProcessorHandle
   }
 }
+
+/**
+ * Stateful processor with support for specifying initial state.
+ * Accepts a user-defined type as initial state to be initialized in the first batch.
+ * This can be used for starting a new streaming query with existing state from a
+ * previous streaming query.
+ */
+@Experimental
+@Evolving
+trait StatefulProcessorWithInitialState[K, I, O, S] extends StatefulProcessor[K, I, O] {
+
+  /**
+   * Function that will be invoked only in the first batch for users to process initial states.
+   *
+   * @param key - grouping key
+   * @param initialState - A row in the initial state to be processed
+   */
+  def handleInitialState(key: K, initialState: S): Unit
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index cb8673d20ed3..b2c443a8cce0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -588,7 +588,46 @@ object TransformWithState {
       outputMode,
       keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
       CatalystSerde.generateObjAttr[U],
-      child
+      child,
+      hasInitialState = false,
+      // the following parameters will not be used in physical plan if hasInitialState = false
+      initialStateGroupingAttrs = groupingAttributes,
+      initialStateDataAttrs = dataAttributes,
+      initialStateDeserializer =
+        UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+      initialState = LocalRelation(encoderFor[K].schema) // empty data set
+    )
+    CatalystSerde.serialize[U](mapped)
+  }
+
+  // This apply() is to invoke TransformWithState object with hasInitialState set to true
+  def apply[K: Encoder, V: Encoder, U: Encoder, S: Encoder](
+      groupingAttributes: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
+      statefulProcessor: StatefulProcessor[K, V, U],
+      timeoutMode: TimeoutMode,
+      outputMode: OutputMode,
+      child: LogicalPlan,
+      initialStateGroupingAttrs: Seq[Attribute],
+      initialStateDataAttrs: Seq[Attribute],
+      initialState: LogicalPlan): LogicalPlan = {
+    val keyEncoder = encoderFor[K]
+    val mapped = new TransformWithState(
+      UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+      UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
+      groupingAttributes,
+      dataAttributes,
+      statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]],
+      timeoutMode,
+      outputMode,
+      keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
+      CatalystSerde.generateObjAttr[U],
+      child,
+      hasInitialState = true,
+      initialStateGroupingAttrs,
+      initialStateDataAttrs,
+      UnresolvedDeserializer(encoderFor[S].deserializer, initialStateDataAttrs),
+      initialState
     )
     CatalystSerde.serialize[U](mapped)
   }
@@ -604,10 +643,18 @@ case class TransformWithState(
     outputMode: OutputMode,
     keyEncoder: ExpressionEncoder[Any],
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectProducer {
+    child: LogicalPlan,
+    hasInitialState: Boolean = false,
+    initialStateGroupingAttrs: Seq[Attribute],
+    initialStateDataAttrs: Seq[Attribute],
+    initialStateDeserializer: Expression,
+    initialState: LogicalPlan) extends BinaryNode with ObjectProducer {
 
-  override protected def withNewChildInternal(newChild: LogicalPlan): TransformWithState =
-    copy(child = newChild)
+  override def left: LogicalPlan = child
+  override def right: LogicalPlan = initialState
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithState =
+    copy(child = newLeft, initialState = newRight)
 }
 
 /** Factory for constructing new `FlatMapGroupsInR` nodes. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 50ab2a41612b..95ad973aee51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.ReduceAggregator
 import org.apache.spark.sql.internal.TypedAggUtils
-import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode}
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode}
 
 /**
 * A [[Dataset]] has been logically grouped by a user specified grouping key.  Users should not
@@ -676,6 +676,42 @@ class KeyValueGroupedDataset[K, V] private[sql](
     )
   }
 
+  /**
+   * (Scala-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+   * Functions as the function above, but with additional initial state.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
+   * @param statefulProcessor Instance of statefulProcessor whose functions will
+   *                          be invoked by the operator.
+   * @param timeoutMode       The timeout mode of the stateful processor.
+   * @param outputMode        The output mode of the stateful processor. Defaults to APPEND mode.
+   * @param initialState      User provided initial state that will be used to initiate state for
+   *                          the query in the first batch.
+   *
+   */
+  private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      timeoutMode: TimeoutMode,
+      outputMode: OutputMode,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
+    Dataset[U](
+      sparkSession,
+      TransformWithState[K, V, U, S](
+        groupingAttributes,
+        dataAttributes,
+        statefulProcessor,
+        timeoutMode,
+        outputMode,
+        child = logicalPlan,
+        initialState.groupingAttributes,
+        initialState.dataAttributes,
+        initialState.queryExecution.analyzed
+      )
+    )
+  }
+
   /**
    * (Scala-specific)
    * Reduces the elements of each group of data using the specified binary function.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f77d0fef4eb9..cc212d99f299 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -752,7 +752,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case TransformWithState(
         keyDeserializer, valueDeserializer, groupingAttributes,
         dataAttributes, statefulProcessor, timeoutMode, outputMode,
-        keyEncoder, outputAttr, child) =>
+        keyEncoder, outputAttr, child, hasInitialState,
+        initialStateGroupingAttrs, initialStateDataAttrs,
+        initialStateDeserializer, initialState) =>
         val execPlan = TransformWithStateExec(
           keyDeserializer,
           valueDeserializer,
@@ -767,7 +769,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           batchTimestampMs = None,
           eventTimeWatermarkForLateEvents = None,
           eventTimeWatermarkForEviction = None,
-          planLater(child))
+          planLater(child),
+          isStreaming = true,
+          hasInitialState,
+          initialStateGroupingAttrs,
+          initialStateDataAttrs,
+          initialStateDeserializer,
+          planLater(initialState))
         execPlan :: Nil
       case _ =>
         Nil
@@ -918,10 +926,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         ) :: Nil
       case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
           dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder,
-          outputObjAttr, child) =>
+          outputObjAttr, child, hasInitialState,
+          initialStateGroupingAttrs, initialStateDataAttrs,
+          initialStateDeserializer, initialState) =>
         TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer,
           groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode,
-          keyEncoder, outputObjAttr, planLater(child)) :: Nil
+          keyEncoder, outputObjAttr, planLater(child), hasInitialState,
+          initialStateGroupingAttrs, initialStateDataAttrs,
+          initialStateDeserializer, planLater(initialState)) :: Nil
 
       case _: FlatMapGroupsInPandasWithState =>
         // TODO(SPARK-40443): support applyInPandasWithState in batch query
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 14007eb4b101..cfccfff3a138 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -268,11 +268,13 @@ class IncrementalExecution(
         )
 
       case t: TransformWithStateExec =>
+        val hasInitialState = (currentBatchId == 0L && t.hasInitialState)
         t.copy(
           stateInfo = Some(nextStatefulOperationStateInfo()),
           batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
           eventTimeWatermarkForLateEvents = None,
-          eventTimeWatermarkForEviction = None
+          eventTimeWatermarkForEviction = None,
+          hasInitialState = hasInitialState
         )
 
       case m: FlatMapGroupsInPandasWithStateExec =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index d3640ebd8850..36b957f9d430 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
 import java.util.UUID
 import java.util.concurrent.TimeUnit.NANOSECONDS
 
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -26,9 +27,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode}
+import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils}
 
@@ -65,8 +67,13 @@ case class TransformWithStateExec(
     eventTimeWatermarkForLateEvents: Option[Long],
     eventTimeWatermarkForEviction: Option[Long],
     child: SparkPlan,
-    isStreaming: Boolean = true)
-  extends UnaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec {
+    isStreaming: Boolean = true,
+    hasInitialState: Boolean = false,
+    initialStateGroupingAttrs: Seq[Attribute],
+    initialStateDataAttrs: Seq[Attribute],
+    initialStateDeserializer: Expression,
+    initialState: SparkPlan)
+  extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec {
 
   override def shortName: String = "transformWithStateExec"
 
@@ -85,8 +92,13 @@ case class TransformWithStateExec(
     }
   }
 
-  override protected def withNewChildInternal(
-    newChild: SparkPlan): TransformWithStateExec = copy(child = newChild)
+  override def left: SparkPlan = child
+
+  override def right: SparkPlan = initialState
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateExec =
+    copy(child = newLeft, initialState = newRight)
 
   override def keyExpressions: Seq[Attribute] = groupingAttributes
 
@@ -94,14 +106,25 @@ case class TransformWithStateExec(
 
   protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
 
+  /**
+   * Distribute by grouping attributes - We need the underlying data and the initial state data
+   * to have the same grouping so that the data are co-located on the same task.
+   */
   override def requiredChildDistribution: Seq[Distribution] = {
-    StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes,
-      getStateInfo, conf) ::
-      Nil
+    StatefulOperatorPartitioning.getCompatibleDistribution(
+      groupingAttributes, getStateInfo, conf) ::
+    StatefulOperatorPartitioning.getCompatibleDistribution(
+        initialStateGroupingAttrs, getStateInfo, conf) ::
+    Nil
   }
 
+  /**
+   * We need the initial state to also use the ordering as the data so that we can co-locate the
+   * keys from the underlying data and the initial state.
+   */
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
-    groupingAttributes.map(SortOrder(_, Ascending)))
+    groupingAttributes.map(SortOrder(_, Ascending)),
+    initialStateGroupingAttrs.map(SortOrder(_, Ascending)))
 
   private def handleInputRows(keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow]):
     Iterator[InternalRow] = {
@@ -127,6 +150,33 @@ case class TransformWithStateExec(
     mappedIterator
   }
 
+  private def processInitialStateRows(
+      keyRow: UnsafeRow,
+      initStateIter: Iterator[InternalRow]): Unit = {
+    val getKeyObj =
+      ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+
+    val getInitStateValueObj =
+      ObjectOperator.deserializeRowToObject(initialStateDeserializer, initialStateDataAttrs)
+
+    val keyObj = getKeyObj(keyRow) // convert key to objects
+    ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
+    val initStateObjIter = initStateIter.map(getInitStateValueObj.apply)
+
+    var seenInitStateOnKey = false
+    initStateObjIter.foreach { initState =>
+      // cannot re-initialize state on the same grouping key during initial state handling
+      if (seenInitStateOnKey) {
+        throw StateStoreErrors.cannotReInitializeStateOnKey(keyObj.toString)
+      }
+      seenInitStateOnKey = true
+      statefulProcessor
+        .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
+        .handleInitialState(keyObj, initState)
+    }
+    ImplicitGroupingKeyTracker.removeImplicitKey()
+  }
+
   private def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
     val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
     groupedIter.flatMap { case (keyRow, valueRowIter) =>
@@ -263,58 +313,108 @@ case class TransformWithStateExec(
       case _ =>
     }
 
-    if (isStreaming) {
-      child.execute().mapPartitionsWithStateStore[InternalRow](
+    if (hasInitialState) {
+      val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+      val hadoopConfBroadcast = sparkContext.broadcast(
+        new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+      child.execute().stateStoreAwareZipPartitions(
+        initialState.execute(),
         getStateInfo,
-        schemaForKeyRow,
-        schemaForValueRow,
-        NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
-        session.sqlContext.sessionState,
-        Some(session.sqlContext.streams.stateStoreCoordinator),
-        useColumnFamilies = true,
-        useMultipleValuesPerKey = true
-      ) {
-        case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
-          processData(store, singleIterator)
+        storeNames = Seq(),
+        session.sqlContext.streams.stateStoreCoordinator) {
+        // The state store aware zip partitions will provide us with two iterators,
+        // child data iterator and the initial state iterator per partition.
+        case (partitionId, childDataIterator, initStateIterator) =>
+          if (isStreaming) {
+            val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
+              stateInfo.get.operatorId, partitionId)
+            val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
+            val store = StateStore.get(
+              storeProviderId = storeProviderId,
+              keySchema = schemaForKeyRow,
+              valueSchema = schemaForValueRow,
+              NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+              version = stateInfo.get.storeVersion,
+              useColumnFamilies = true,
+              storeConf = storeConf,
+              hadoopConf = hadoopConfBroadcast.value.value
+            )
+
+            processDataWithInitialState(store, childDataIterator, initStateIterator)
+          } else {
+            initNewStateStoreAndProcessData(partitionId, hadoopConfBroadcast) { store =>
+              processDataWithInitialState(store, childDataIterator, initStateIterator)
+            }
+          }
       }
     } else {
-      // If the query is running in batch mode, we need to create a new StateStore and instantiate
-      // a temp directory on the executors in mapPartitionsWithIndex.
-      val broadcastedHadoopConf =
-        new SerializableConfiguration(session.sessionState.newHadoopConf())
-      child.execute().mapPartitionsWithIndex[InternalRow](
-        (i, iter) => {
-          val providerId = {
-            val tempDirPath = Utils.createTempDir().getAbsolutePath
-            new StateStoreProviderId(
-              StateStoreId(tempDirPath, 0, i), getStateInfo.queryRunId)
+      if (isStreaming) {
+        child.execute().mapPartitionsWithStateStore[InternalRow](
+          getStateInfo,
+          schemaForKeyRow,
+          schemaForValueRow,
+          NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+          session.sqlContext.sessionState,
+          Some(session.sqlContext.streams.stateStoreCoordinator),
+          useColumnFamilies = true
+        ) {
+          case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
+            processData(store, singleIterator)
+        }
+      } else {
+        // If the query is running in batch mode, we need to create a new StateStore and instantiate
+        // a temp directory on the executors in mapPartitionsWithIndex.
+        val hadoopConfBroadcast = sparkContext.broadcast(
+          new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+        child.execute().mapPartitionsWithIndex[InternalRow](
+          (i: Int, iter: Iterator[InternalRow]) => {
+            initNewStateStoreAndProcessData(i, hadoopConfBroadcast) { store =>
+              processData(store, iter)
+            }
           }
+        )
+      }
+    }
+  }
 
-          val sqlConf = new SQLConf()
-          sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
-            classOf[RocksDBStateStoreProvider].getName)
-          val storeConf = new StateStoreConf(sqlConf)
-
-          // Create StateStoreProvider for this partition
-          val stateStoreProvider = StateStoreProvider.createAndInit(
-            providerId,
-            schemaForKeyRow,
-            schemaForValueRow,
-            NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
-            useColumnFamilies = true,
-            storeConf = storeConf,
-            hadoopConf = broadcastedHadoopConf.value,
-            useMultipleValuesPerKey = true)
-
-          val store = stateStoreProvider.getStore(0)
-          val outputIterator = processData(store, iter)
-          CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator.iterator, {
-            stateStoreProvider.close()
-            statefulProcessor.close()
-          })
-        }
-      )
+  /**
+   * Create a new StateStore for given partitionId and instantiate a temp directory
+   * on the executors. Process data and close the stateStore provider afterwards.
+   */
+  private def initNewStateStoreAndProcessData(
+      partitionId: Int,
+      hadoopConfBroadcast: Broadcast[SerializableConfiguration])
+    (f: StateStore => CompletionIterator[InternalRow, Iterator[InternalRow]]):
+    CompletionIterator[InternalRow, Iterator[InternalRow]] = {
+
+    val providerId = {
+      val tempDirPath = Utils.createTempDir().getAbsolutePath
+      new StateStoreProviderId(
+        StateStoreId(tempDirPath, 0, partitionId), getStateInfo.queryRunId)
     }
+
+    val sqlConf = new SQLConf()
+    sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      classOf[RocksDBStateStoreProvider].getName)
+    val storeConf = new StateStoreConf(sqlConf)
+
+    // Create StateStoreProvider for this partition
+    val stateStoreProvider = StateStoreProvider.createAndInit(
+      providerId,
+      schemaForKeyRow,
+      schemaForValueRow,
+      NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+      useColumnFamilies = true,
+      storeConf = storeConf,
+      hadoopConf = hadoopConfBroadcast.value.value,
+      useMultipleValuesPerKey = true)
+
+    val store = stateStoreProvider.getStore(0)
+    val outputIterator = f(store)
+    CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator.iterator, {
+      stateStoreProvider.close()
+      statefulProcessor.close()
+    })
   }
 
   /**
@@ -333,8 +433,37 @@ case class TransformWithStateExec(
     processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
     processDataWithPartition(singleIterator, store, processorHandle)
   }
+
+  private def processDataWithInitialState(
+      store: StateStore,
+      childDataIterator: Iterator[InternalRow],
+      initStateIterator: Iterator[InternalRow]):
+    CompletionIterator[InternalRow, Iterator[InternalRow]] = {
+    val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
+      keyEncoder, timeoutMode, isStreaming)
+    assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
+    statefulProcessor.setHandle(processorHandle)
+    statefulProcessor.init(outputMode, timeoutMode)
+    processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+
+    // Check if is first batch
+    // Only process initial states for first batch
+    if (processorHandle.getQueryInfo().getBatchId == 0) {
+      // If the user provided initial state, we need to have the initial state and the
+      // data in the same partition so that we can still have just one commit at the end.
+      val groupedInitialStateIter = GroupedIterator(initStateIterator,
+        initialStateGroupingAttrs, initialState.output)
+      groupedInitialStateIter.foreach {
+        case (keyRow, valueRowIter) =>
+          processInitialStateRows(keyRow.asInstanceOf[UnsafeRow], valueRowIter)
+      }
+    }
+
+    processDataWithPartition(childDataIterator, store, processorHandle)
+  }
 }
 
+// scalastyle:off
 object TransformWithStateExec {
 
   // Plan logical transformWithState for batch queries
@@ -348,7 +477,12 @@ object TransformWithStateExec {
       outputMode: OutputMode,
       keyEncoder: ExpressionEncoder[Any],
       outputObjAttr: Attribute,
-      child: SparkPlan): SparkPlan = {
+      child: SparkPlan,
+      hasInitialState: Boolean = false,
+      initialStateGroupingAttrs: Seq[Attribute],
+      initialStateDataAttrs: Seq[Attribute],
+      initialStateDeserializer: Expression,
+      initialState: SparkPlan): SparkPlan = {
     val shufflePartitions = child.session.sessionState.conf.numShufflePartitions
     val statefulOperatorStateInfo = StatefulOperatorStateInfo(
       checkpointLocation = "", // empty checkpointLocation will be populated in doExecute
@@ -373,6 +507,12 @@ object TransformWithStateExec {
       None,
       None,
       child,
-      isStreaming = false)
+      isStreaming = false,
+      hasInitialState,
+      initialStateGroupingAttrs,
+      initialStateDataAttrs,
+      initialStateDeserializer,
+      initialState)
   }
 }
+// scalastyle:on
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index a8d4c06bc83c..2f72cbb0b0fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -112,6 +112,11 @@ object StateStoreErrors {
       handleState: String): StatefulProcessorCannotPerformOperationWithInvalidHandleState = {
     new StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType, handleState)
   }
+
+  def cannotReInitializeStateOnKey(groupingKey: String):
+    StatefulProcessorCannotReInitializeState = {
+    new StatefulProcessorCannotReInitializeState(groupingKey)
+  }
 }
 
 class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String)
@@ -157,6 +162,11 @@ class StatefulProcessorCannotPerformOperationWithInvalidHandleState(
     messageParameters = Map("operationType" -> operationType, "handleState" -> handleState)
   )
 
+class StatefulProcessorCannotReInitializeState(groupingKey: String)
+  extends SparkUnsupportedOperationException(
+  errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
+  messageParameters = Map("groupingKey" -> groupingKey))
+
 class StateStoreUnsupportedOperationOnMissingColumnFamily(
     operationType: String,
     colFamilyName: String) extends SparkUnsupportedOperationException(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
index d7c5ce3815b0..db8cb8b810af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
 
 case class InputMapRow(key: String, action: String, value: (String, String))
@@ -82,7 +82,8 @@ class TestMapStateProcessor
  * Class that adds integration tests for MapState types used in arbitrary stateful
  * operators such as transformWithState.
  */
-class TransformWithMapStateSuite extends StreamTest {
+class TransformWithMapStateSuite extends StreamTest
+  with AlsoTestWithChangelogCheckpointingEnabled {
   import testImplicits._
 
   private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
new file mode 100644
index 000000000000..9f2e2c2d9f02
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.sql.{Encoders, KeyValueGroupedDataset}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+
+case class InitInputRow(key: String, action: String, value: Double)
+case class InputRowForInitialState(
+    key: String, value: Double, entries: List[Double], mapping: Map[Double, Int])
+
+abstract class StatefulProcessorWithInitialStateTestClass[V]
+    extends StatefulProcessorWithInitialState[
+        String, InitInputRow, (String, String, Double), V] {
+  @transient var _valState: ValueState[Double] = _
+  @transient var _listState: ListState[Double] = _
+  @transient var _mapState: MapState[Double, Int] = _
+
+  override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
+    _valState = getHandle.getValueState[Double]("testValueInit", Encoders.scalaDouble)
+    _listState = getHandle.getListState[Double]("testListInit", Encoders.scalaDouble)
+    _mapState = getHandle.getMapState[Double, Int](
+      "testMapInit", Encoders.scalaDouble, Encoders.scalaInt)
+  }
+
+  override def close(): Unit = {}
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[InitInputRow],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, Double)] = {
+    var output = List[(String, String, Double)]()
+    for (row <- inputRows) {
+      if (row.action == "getOption") {
+        output = (key, row.action, _valState.getOption().getOrElse(-1.0)) :: output
+      } else if (row.action == "update") {
+        _valState.update(row.value)
+      } else if (row.action == "remove") {
+        _valState.clear()
+      } else if (row.action == "getList") {
+        _listState.get().foreach { element =>
+          output = (key, row.action, element) :: output
+        }
+      } else if (row.action == "appendList") {
+        _listState.appendValue(row.value)
+      } else if (row.action == "clearList") {
+        _listState.clear()
+      } else if (row.action == "getCount") {
+        val count =
+          if (!_mapState.containsKey(row.value)) 0
+          else _mapState.getValue(row.value)
+        output = (key, row.action, count.toDouble) :: output
+      } else if (row.action == "incCount") {
+        val count =
+          if (!_mapState.containsKey(row.value)) 0
+          else _mapState.getValue(row.value)
+        _mapState.updateValue(row.value, count + 1)
+      } else if (row.action == "clearCount") {
+        _mapState.removeKey(row.value)
+      }
+    }
+    output.iterator
+  }
+}
+
+class AccumulateStatefulProcessorWithInitState
+    extends StatefulProcessorWithInitialStateTestClass[(String, Double)] {
+  override def handleInitialState(
+      key: String,
+      initialState: (String, Double)): Unit = {
+    _valState.update(initialState._2)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[InitInputRow],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, Double)] = {
+    var output = List[(String, String, Double)]()
+    for (row <- inputRows) {
+      if (row.action == "getOption") {
+        output = (key, row.action, _valState.getOption().getOrElse(0.0)) :: output
+      } else if (row.action == "add") {
+        // Update state variable as accumulative sum
+        val accumulateSum = _valState.getOption().getOrElse(0.0) + row.value
+        _valState.update(accumulateSum)
+      } else if (row.action == "remove") {
+        _valState.clear()
+      }
+    }
+    output.iterator
+  }
+}
+
+class InitialStateInMemoryTestClass
+  extends StatefulProcessorWithInitialStateTestClass[InputRowForInitialState] {
+  override def handleInitialState(
+      key: String,
+      initialState: InputRowForInitialState): Unit = {
+    _valState.update(initialState.value)
+    _listState.appendList(initialState.entries.toArray)
+    val inMemoryMap = initialState.mapping
+    inMemoryMap.foreach { kvPair =>
+      _mapState.updateValue(kvPair._1, kvPair._2)
+    }
+  }
+}
+
+/**
+ * Class that adds tests for transformWithState stateful
+ * streaming operator with user-defined initial state
+ */
+class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
+  with AlsoTestWithChangelogCheckpointingEnabled {
+
+  import testImplicits._
+
+  private def createInitialDfForTest: KeyValueGroupedDataset[String, (String, Double)] = {
+    Seq(("init_1", 40.0), ("init_2", 100.0)).toDS()
+      .groupByKey(x => x._1)
+      .mapValues(x => x)
+  }
+
+
+  test("transformWithStateWithInitialState - correctness test, " +
+    "run with multiple state variables - in-memory type") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InitInputRow]
+      val kvDataSet = inputData.toDS()
+        .groupByKey(x => x.key)
+      val initStateDf =
+        Seq(InputRowForInitialState("init_1", 40.0, List(40.0), Map(40.0 -> 1)),
+          InputRowForInitialState("init_2", 100.0, List(100.0), Map(100.0 -> 1)))
+          .toDS().groupByKey(x => x.key).mapValues(x => x)
+      val query = kvDataSet.transformWithState(new InitialStateInMemoryTestClass(),
+            TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf)
+
+      testStream(query, OutputMode.Update())(
+        // non-exist key test
+        AddData(inputData, InitInputRow("k1", "update", 37.0)),
+        AddData(inputData, InitInputRow("k2", "update", 40.0)),
+        AddData(inputData, InitInputRow("non-exist", "getOption", -1.0)),
+        CheckNewAnswer(("non-exist", "getOption", -1.0)),
+        AddData(inputData, InitInputRow("k1", "appendList", 37.0)),
+        AddData(inputData, InitInputRow("k2", "appendList", 40.0)),
+        AddData(inputData, InitInputRow("non-exist", "getList", -1.0)),
+        CheckNewAnswer(),
+
+        AddData(inputData, InitInputRow("k1", "incCount", 37.0)),
+        AddData(inputData, InitInputRow("k2", "incCount", 40.0)),
+        AddData(inputData, InitInputRow("non-exist", "getCount", -1.0)),
+        CheckNewAnswer(("non-exist", "getCount", 0.0)),
+        AddData(inputData, InitInputRow("k2", "incCount", 40.0)),
+        AddData(inputData, InitInputRow("k2", "getCount", 40.0)),
+        CheckNewAnswer(("k2", "getCount", 2.0)),
+
+        // test every row in initial State is processed
+        AddData(inputData, InitInputRow("init_1", "getOption", -1.0)),
+        CheckNewAnswer(("init_1", "getOption", 40.0)),
+        AddData(inputData, InitInputRow("init_2", "getOption", -1.0)),
+        CheckNewAnswer(("init_2", "getOption", 100.0)),
+
+        AddData(inputData, InitInputRow("init_1", "getList", -1.0)),
+        CheckNewAnswer(("init_1", "getList", 40.0)),
+        AddData(inputData, InitInputRow("init_2", "getList", -1.0)),
+        CheckNewAnswer(("init_2", "getList", 100.0)),
+
+        AddData(inputData, InitInputRow("init_1", "getCount", 40.0)),
+        CheckNewAnswer(("init_1", "getCount", 1.0)),
+        AddData(inputData, InitInputRow("init_2", "getCount", 100.0)),
+        CheckNewAnswer(("init_2", "getCount", 1.0)),
+
+        // Update row with key in initial row will work
+        AddData(inputData, InitInputRow("init_1", "update", 50.0)),
+        AddData(inputData, InitInputRow("init_1", "getOption", -1.0)),
+        CheckNewAnswer(("init_1", "getOption", 50.0)),
+        AddData(inputData, InitInputRow("init_1", "remove", -1.0)),
+        AddData(inputData, InitInputRow("init_1", "getOption", -1.0)),
+        CheckNewAnswer(("init_1", "getOption", -1.0)),
+
+        AddData(inputData, InitInputRow("init_1", "appendList", 50.0)),
+        AddData(inputData, InitInputRow("init_1", "getList", -1.0)),
+        CheckNewAnswer(("init_1", "getList", 50.0), ("init_1", "getList", 
40.0)),
+
+        AddData(inputData, InitInputRow("init_1", "incCount", 40.0)),
+        AddData(inputData, InitInputRow("init_1", "getCount", 40.0)),
+        CheckNewAnswer(("init_1", "getCount", 2.0)),
+
+        // test remove
+        AddData(inputData, InitInputRow("k1", "remove", -1.0)),
+        AddData(inputData, InitInputRow("k1", "getOption", -1.0)),
+        CheckNewAnswer(("k1", "getOption", -1.0)),
+
+        AddData(inputData, InitInputRow("init_1", "clearCount", -1.0)),
+        AddData(inputData, InitInputRow("init_1", "getCount", -1.0)),
+        CheckNewAnswer(("init_1", "getCount", 0.0)),
+
+        AddData(inputData, InitInputRow("init_1", "clearList", -1.0)),
+        AddData(inputData, InitInputRow("init_1", "getList", -1.0)),
+        CheckNewAnswer()
+      )
+    }
+  }
+
+  test("transformWithStateWithInitialState -" +
+    " correctness test, processInitialState should only run once") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val initStateDf = createInitialDfForTest
+      val inputData = MemoryStream[InitInputRow]
+      val query = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new AccumulateStatefulProcessorWithInitState(),
+          TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf
+        )
+      testStream(query, OutputMode.Update())(
+        AddData(inputData, InitInputRow("init_1", "add", 50.0)),
+        AddData(inputData, InitInputRow("init_2", "add", 60.0)),
+        AddData(inputData, InitInputRow("init_1", "add", 50.0)),
+        // If processInitialState was processed multiple times,
+        // following checks will fail
+        AddData(inputData,
+          InitInputRow("init_1", "getOption", -1.0), InitInputRow("init_2", 
"getOption", -1.0)),
+        CheckNewAnswer(("init_2", "getOption", 160.0), ("init_1", "getOption", 
140.0))
+      )
+    }
+  }
+
+  test("transformWithStateWithInitialState - batch should succeed") {
+    val inputData = Seq(InitInputRow("k1", "add", 37.0), InitInputRow("k1", "getOption", -1.0))
+    val result = inputData.toDS()
+      .groupByKey(x => x.key)
+      .transformWithState(new AccumulateStatefulProcessorWithInitState(),
+        TimeoutMode.NoTimeouts(),
+        OutputMode.Append(),
+        createInitialDfForTest)
+
+    val df = result.toDF()
+    checkAnswer(df, Seq(("k1", "getOption", 37.0)).toDF())
+  }
+
+  test("transformWithStateWithInitialState - " +
+    "cannot re-initialize state during initial state handling") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val initDf = Seq(("init_1", 40.0), ("init_2", 100.0), ("init_1", 50.0)).toDS()
+        .groupByKey(x => x._1).mapValues(x => x)
+      val inputData = MemoryStream[InitInputRow]
+      val query = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new AccumulateStatefulProcessorWithInitState(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Append(),
+          initDf)
+
+      testStream(query, OutputMode.Update())(
+        AddData(inputData, InitInputRow("k1", "add", 50.0)),
+        Execute { q =>
+          val e = intercept[Exception] {
+            q.processAllAvailable()
+          }
+          checkError(
+            exception = e.getCause.asInstanceOf[SparkUnsupportedOperationException],
+            errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
+            sqlState = Some("42802"),
+            parameters = Map("groupingKey" -> "init_1")
+          )
+        }
+      )
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 24b0d59c45c5..24e68e3db9d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -769,4 +769,24 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest {
       }
     )
   }
+
+  test("transformWithStateWithInitialState - streaming with 
hdfsStateStoreProvider should fail") {
+    val inputData = MemoryStream[InitInputRow]
+    val initDf = Seq(("init_1", 40.0), ("init_2", 100.0)).toDS()
+      .groupByKey(x => x._1)
+      .mapValues(x => x)
+    val result = inputData.toDS()
+      .groupByKey(x => x.key)
+      .transformWithState(new AccumulateStatefulProcessorWithInitState(),
+        TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf
+      )
+    testStream(result, OutputMode.Update())(
+      AddData(inputData, InitInputRow("a", "add", -1.0)),
+      ExpectFailure[StateStoreMultipleColumnFamiliesNotSupportedException] {
+        (t: Throwable) => {
+          assert(t.getMessage.contains("not supported"))
+        }
+      }
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

