This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4d72be3abdc4 [SPARK-47363][SS] Initial State without state reader implementation for State API v2 4d72be3abdc4 is described below commit 4d72be3abdc4c651da029bdbd24a574099d45e7c Author: jingz-db <jing.z...@databricks.com> AuthorDate: Thu Mar 28 14:50:46 2024 +0900 [SPARK-47363][SS] Initial State without state reader implementation for State API v2 ### What changes were proposed in this pull request? This PR adds support for users to provide a DataFrame that can be used to instantiate state for the query in the first batch for arbitrary state API v2. Note that populating the initial state will only happen for the first batch of the new streaming query. Trying to re-initialize state for the same grouping key will result in an error. ### Why are the changes needed? These changes are needed to support initial state. The changes are part of the work around adding a new stateful streaming operator for arbitrary state management that provides a bunch of new features listed in the SPIP JIRA here - https://issues.apache.org/jira/browse/SPARK-45939 ### Does this PR introduce _any_ user-facing change? Yes. This PR introduces a new function: ``` def transformWithState( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeoutMode: TimeoutMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S]): Dataset[U] ``` ### How was this patch tested? Unit tests in `TransformWithStateInitialStateSuite` ### Was this patch authored or co-authored using generative AI tooling? No Closes #45467 from jingz-db/initial-state-state-v2. 
Lead-authored-by: jingz-db <jing.z...@databricks.com> Co-authored-by: Jing Zhan <135738831+jingz...@users.noreply.github.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../src/main/resources/error/error-classes.json | 6 + docs/sql-error-conditions.md | 6 + .../spark/sql/streaming/StatefulProcessor.scala | 19 ++ .../spark/sql/catalyst/plans/logical/object.scala | 55 +++- .../apache/spark/sql/KeyValueGroupedDataset.scala | 38 ++- .../spark/sql/execution/SparkStrategies.scala | 20 +- .../execution/streaming/IncrementalExecution.scala | 4 +- .../streaming/TransformWithStateExec.scala | 254 ++++++++++++++---- .../streaming/state/StateStoreErrors.scala | 10 + .../sql/streaming/TransformWithMapStateSuite.scala | 5 +- .../TransformWithStateInitialStateSuite.scala | 293 +++++++++++++++++++++ .../sql/streaming/TransformWithStateSuite.scala | 20 ++ 12 files changed, 661 insertions(+), 69 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 185e86853dfd..11c8204d2c93 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3553,6 +3553,12 @@ ], "sqlState" : "42802" }, + "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY" : { + "message" : [ + "Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=<groupingKey>." + ], + "sqlState" : "42802" + }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=<colFamilyName>." 
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 838ca2fa33c9..85b9e85ac420 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2162,6 +2162,12 @@ Failed to perform stateful processor operation=`<operationType>` with invalid ha Failed to perform stateful processor operation=`<operationType>` with invalid timeoutMode=`<timeoutMode>` +### STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=`<groupingKey>`. + ### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index ad9b807ddf5a..1a61972f0ed0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -91,3 +91,22 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable { statefulProcessorHandle } } + +/** + * Stateful processor with support for specifying initial state. + * Accepts a user-defined type as initial state to be initialized in the first batch. + * This can be used for starting a new streaming query with existing state from a + * previous streaming query. + */ +@Experimental +@Evolving +trait StatefulProcessorWithInitialState[K, I, O, S] extends StatefulProcessor[K, I, O] { + + /** + * Function that will be invoked only in the first batch for users to process initial states. 
+ * + * @param key - grouping key + * @param initialState - A row in the initial state to be processed + */ + def handleInitialState(key: K, initialState: S): Unit +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index cb8673d20ed3..b2c443a8cce0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -588,7 +588,46 @@ object TransformWithState { outputMode, keyEncoder.asInstanceOf[ExpressionEncoder[Any]], CatalystSerde.generateObjAttr[U], - child + child, + hasInitialState = false, + // the following parameters will not be used in physical plan if hasInitialState = false + initialStateGroupingAttrs = groupingAttributes, + initialStateDataAttrs = dataAttributes, + initialStateDeserializer = + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + initialState = LocalRelation(encoderFor[K].schema) // empty data set + ) + CatalystSerde.serialize[U](mapped) + } + + // This apply() is to invoke TransformWithState object with hasInitialState set to true + def apply[K: Encoder, V: Encoder, U: Encoder, S: Encoder]( + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + statefulProcessor: StatefulProcessor[K, V, U], + timeoutMode: TimeoutMode, + outputMode: OutputMode, + child: LogicalPlan, + initialStateGroupingAttrs: Seq[Attribute], + initialStateDataAttrs: Seq[Attribute], + initialState: LogicalPlan): LogicalPlan = { + val keyEncoder = encoderFor[K] + val mapped = new TransformWithState( + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes), + groupingAttributes, + dataAttributes, + statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], + timeoutMode, + outputMode, 
+ keyEncoder.asInstanceOf[ExpressionEncoder[Any]], + CatalystSerde.generateObjAttr[U], + child, + hasInitialState = true, + initialStateGroupingAttrs, + initialStateDataAttrs, + UnresolvedDeserializer(encoderFor[S].deserializer, initialStateDataAttrs), + initialState ) CatalystSerde.serialize[U](mapped) } @@ -604,10 +643,18 @@ case class TransformWithState( outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], outputObjAttr: Attribute, - child: LogicalPlan) extends UnaryNode with ObjectProducer { + child: LogicalPlan, + hasInitialState: Boolean = false, + initialStateGroupingAttrs: Seq[Attribute], + initialStateDataAttrs: Seq[Attribute], + initialStateDeserializer: Expression, + initialState: LogicalPlan) extends BinaryNode with ObjectProducer { - override protected def withNewChildInternal(newChild: LogicalPlan): TransformWithState = - copy(child = newChild) + override def left: LogicalPlan = child + override def right: LogicalPlan = initialState + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithState = + copy(child = newLeft, initialState = newRight) } /** Factory for constructing new `FlatMapGroupsInR` nodes. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 50ab2a41612b..95ad973aee51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.TypedAggUtils -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -676,6 +676,42 @@ class KeyValueGroupedDataset[K, V] private[sql]( ) } + /** + * (Scala-specific) + * Invokes methods defined in the stateful processor used in arbitrary state API v2. + * Functions as the function above, but with additional initial state. + * + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @tparam S The type of initial state objects. Must be encodable to Spark SQL types. + * @param statefulProcessor Instance of statefulProcessor whose functions will + * be invoked by the operator. + * @param timeoutMode The timeout mode of the stateful processor. + * @param outputMode The output mode of the stateful processor. Defaults to APPEND mode. + * @param initialState User provided initial state that will be used to initiate state for + * the query in the first batch. 
+ * + */ + private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + timeoutMode: TimeoutMode, + outputMode: OutputMode, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { + Dataset[U]( + sparkSession, + TransformWithState[K, V, U, S]( + groupingAttributes, + dataAttributes, + statefulProcessor, + timeoutMode, + outputMode, + child = logicalPlan, + initialState.groupingAttributes, + initialState.dataAttributes, + initialState.queryExecution.analyzed + ) + ) + } + /** * (Scala-specific) * Reduces the elements of each group of data using the specified binary function. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f77d0fef4eb9..cc212d99f299 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -752,7 +752,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case TransformWithState( keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, - keyEncoder, outputAttr, child) => + keyEncoder, outputAttr, child, hasInitialState, + initialStateGroupingAttrs, initialStateDataAttrs, + initialStateDeserializer, initialState) => val execPlan = TransformWithStateExec( keyDeserializer, valueDeserializer, @@ -767,7 +769,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { batchTimestampMs = None, eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, - planLater(child)) + planLater(child), + isStreaming = true, + hasInitialState, + initialStateGroupingAttrs, + initialStateDataAttrs, + initialStateDeserializer, + planLater(initialState)) execPlan :: Nil case _ => Nil @@ -918,10 +926,14 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] { ) :: Nil case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder, - outputObjAttr, child) => + outputObjAttr, child, hasInitialState, + initialStateGroupingAttrs, initialStateDataAttrs, + initialStateDeserializer, initialState) => TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, - keyEncoder, outputObjAttr, planLater(child)) :: Nil + keyEncoder, outputObjAttr, planLater(child), hasInitialState, + initialStateGroupingAttrs, initialStateDataAttrs, + initialStateDeserializer, planLater(initialState)) :: Nil case _: FlatMapGroupsInPandasWithState => // TODO(SPARK-40443): support applyInPandasWithState in batch query diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 14007eb4b101..cfccfff3a138 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -268,11 +268,13 @@ class IncrementalExecution( ) case t: TransformWithStateExec => + val hasInitialState = (currentBatchId == 0L && t.hasInitialState) t.copy( stateInfo = Some(nextStatefulOperationStateInfo()), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), eventTimeWatermarkForLateEvents = None, - eventTimeWatermarkForEviction = None + eventTimeWatermarkForEviction = None, + hasInitialState = hasInitialState ) case m: FlatMapGroupsInPandasWithStateExec => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 
d3640ebd8850..36b957f9d430 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit.NANOSECONDS +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -26,9 +27,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode} import org.apache.spark.sql.types._ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} @@ -65,8 +67,13 @@ case class TransformWithStateExec( eventTimeWatermarkForLateEvents: Option[Long], eventTimeWatermarkForEviction: Option[Long], child: SparkPlan, - isStreaming: Boolean = true) - extends UnaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec { + isStreaming: Boolean = true, + hasInitialState: Boolean = false, + initialStateGroupingAttrs: Seq[Attribute], + initialStateDataAttrs: Seq[Attribute], + initialStateDeserializer: Expression, + initialState: SparkPlan) + extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec { override def shortName: String = 
"transformWithStateExec" @@ -85,8 +92,13 @@ case class TransformWithStateExec( } } - override protected def withNewChildInternal( - newChild: SparkPlan): TransformWithStateExec = copy(child = newChild) + override def left: SparkPlan = child + + override def right: SparkPlan = initialState + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateExec = + copy(child = newLeft, initialState = newRight) override def keyExpressions: Seq[Attribute] = groupingAttributes @@ -94,14 +106,25 @@ case class TransformWithStateExec( protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType) + /** + * Distribute by grouping attributes - We need the underlying data and the initial state data + * to have the same grouping so that the data are co-located on the same task. + */ override def requiredChildDistribution: Seq[Distribution] = { - StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes, - getStateInfo, conf) :: - Nil + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: + StatefulOperatorPartitioning.getCompatibleDistribution( + initialStateGroupingAttrs, getStateInfo, conf) :: + Nil } + /** + * We need the initial state to also use the ordering as the data so that we can co-locate the + * keys from the underlying data and the initial state. 
+ */ override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( - groupingAttributes.map(SortOrder(_, Ascending))) + groupingAttributes.map(SortOrder(_, Ascending)), + initialStateGroupingAttrs.map(SortOrder(_, Ascending))) private def handleInputRows(keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow]): Iterator[InternalRow] = { @@ -127,6 +150,33 @@ case class TransformWithStateExec( mappedIterator } + private def processInitialStateRows( + keyRow: UnsafeRow, + initStateIter: Iterator[InternalRow]): Unit = { + val getKeyObj = + ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + + val getInitStateValueObj = + ObjectOperator.deserializeRowToObject(initialStateDeserializer, initialStateDataAttrs) + + val keyObj = getKeyObj(keyRow) // convert key to objects + ImplicitGroupingKeyTracker.setImplicitKey(keyObj) + val initStateObjIter = initStateIter.map(getInitStateValueObj.apply) + + var seenInitStateOnKey = false + initStateObjIter.foreach { initState => + // cannot re-initialize state on the same grouping key during initial state handling + if (seenInitStateOnKey) { + throw StateStoreErrors.cannotReInitializeStateOnKey(keyObj.toString) + } + seenInitStateOnKey = true + statefulProcessor + .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]] + .handleInitialState(keyObj, initState) + } + ImplicitGroupingKeyTracker.removeImplicitKey() + } + private def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => @@ -263,58 +313,108 @@ case class TransformWithStateExec( case _ => } - if (isStreaming) { - child.execute().mapPartitionsWithStateStore[InternalRow]( + if (hasInitialState) { + val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf) + val hadoopConfBroadcast = sparkContext.broadcast( + new 
SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf())) + child.execute().stateStoreAwareZipPartitions( + initialState.execute(), getStateInfo, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), - session.sqlContext.sessionState, - Some(session.sqlContext.streams.stateStoreCoordinator), - useColumnFamilies = true, - useMultipleValuesPerKey = true - ) { - case (store: StateStore, singleIterator: Iterator[InternalRow]) => - processData(store, singleIterator) + storeNames = Seq(), + session.sqlContext.streams.stateStoreCoordinator) { + // The state store aware zip partitions will provide us with two iterators, + // child data iterator and the initial state iterator per partition. + case (partitionId, childDataIterator, initStateIterator) => + if (isStreaming) { + val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation, + stateInfo.get.operatorId, partitionId) + val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId) + val store = StateStore.get( + storeProviderId = storeProviderId, + keySchema = schemaForKeyRow, + valueSchema = schemaForValueRow, + NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + version = stateInfo.get.storeVersion, + useColumnFamilies = true, + storeConf = storeConf, + hadoopConf = hadoopConfBroadcast.value.value + ) + + processDataWithInitialState(store, childDataIterator, initStateIterator) + } else { + initNewStateStoreAndProcessData(partitionId, hadoopConfBroadcast) { store => + processDataWithInitialState(store, childDataIterator, initStateIterator) + } + } } } else { - // If the query is running in batch mode, we need to create a new StateStore and instantiate - // a temp directory on the executors in mapPartitionsWithIndex. 
- val broadcastedHadoopConf = - new SerializableConfiguration(session.sessionState.newHadoopConf()) - child.execute().mapPartitionsWithIndex[InternalRow]( - (i, iter) => { - val providerId = { - val tempDirPath = Utils.createTempDir().getAbsolutePath - new StateStoreProviderId( - StateStoreId(tempDirPath, 0, i), getStateInfo.queryRunId) + if (isStreaming) { + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateInfo, + schemaForKeyRow, + schemaForValueRow, + NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + session.sqlContext.sessionState, + Some(session.sqlContext.streams.stateStoreCoordinator), + useColumnFamilies = true + ) { + case (store: StateStore, singleIterator: Iterator[InternalRow]) => + processData(store, singleIterator) + } + } else { + // If the query is running in batch mode, we need to create a new StateStore and instantiate + // a temp directory on the executors in mapPartitionsWithIndex. + val hadoopConfBroadcast = sparkContext.broadcast( + new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf())) + child.execute().mapPartitionsWithIndex[InternalRow]( + (i: Int, iter: Iterator[InternalRow]) => { + initNewStateStoreAndProcessData(i, hadoopConfBroadcast) { store => + processData(store, iter) + } } + ) + } + } + } - val sqlConf = new SQLConf() - sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key, - classOf[RocksDBStateStoreProvider].getName) - val storeConf = new StateStoreConf(sqlConf) - - // Create StateStoreProvider for this partition - val stateStoreProvider = StateStoreProvider.createAndInit( - providerId, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), - useColumnFamilies = true, - storeConf = storeConf, - hadoopConf = broadcastedHadoopConf.value, - useMultipleValuesPerKey = true) - - val store = stateStoreProvider.getStore(0) - val outputIterator = processData(store, iter) - CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator.iterator, { - 
stateStoreProvider.close() - statefulProcessor.close() - }) - } - ) + /** + * Create a new StateStore for given partitionId and instantiate a temp directory + * on the executors. Process data and close the stateStore provider afterwards. + */ + private def initNewStateStoreAndProcessData( + partitionId: Int, + hadoopConfBroadcast: Broadcast[SerializableConfiguration]) + (f: StateStore => CompletionIterator[InternalRow, Iterator[InternalRow]]): + CompletionIterator[InternalRow, Iterator[InternalRow]] = { + + val providerId = { + val tempDirPath = Utils.createTempDir().getAbsolutePath + new StateStoreProviderId( + StateStoreId(tempDirPath, 0, partitionId), getStateInfo.queryRunId) } + + val sqlConf = new SQLConf() + sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + classOf[RocksDBStateStoreProvider].getName) + val storeConf = new StateStoreConf(sqlConf) + + // Create StateStoreProvider for this partition + val stateStoreProvider = StateStoreProvider.createAndInit( + providerId, + schemaForKeyRow, + schemaForValueRow, + NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + useColumnFamilies = true, + storeConf = storeConf, + hadoopConf = hadoopConfBroadcast.value.value, + useMultipleValuesPerKey = true) + + val store = stateStoreProvider.getStore(0) + val outputIterator = f(store) + CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator.iterator, { + stateStoreProvider.close() + statefulProcessor.close() + }) } /** @@ -333,8 +433,37 @@ case class TransformWithStateExec( processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) processDataWithPartition(singleIterator, store, processorHandle) } + + private def processDataWithInitialState( + store: StateStore, + childDataIterator: Iterator[InternalRow], + initStateIterator: Iterator[InternalRow]): + CompletionIterator[InternalRow, Iterator[InternalRow]] = { + val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, + keyEncoder, timeoutMode, isStreaming) 
+ assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) + statefulProcessor.setHandle(processorHandle) + statefulProcessor.init(outputMode, timeoutMode) + processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) + + // Check if is first batch + // Only process initial states for first batch + if (processorHandle.getQueryInfo().getBatchId == 0) { + // If the user provided initial state, we need to have the initial state and the + // data in the same partition so that we can still have just one commit at the end. + val groupedInitialStateIter = GroupedIterator(initStateIterator, + initialStateGroupingAttrs, initialState.output) + groupedInitialStateIter.foreach { + case (keyRow, valueRowIter) => + processInitialStateRows(keyRow.asInstanceOf[UnsafeRow], valueRowIter) + } + } + + processDataWithPartition(childDataIterator, store, processorHandle) + } } +// scalastyle:off object TransformWithStateExec { // Plan logical transformWithState for batch queries @@ -348,7 +477,12 @@ object TransformWithStateExec { outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], outputObjAttr: Attribute, - child: SparkPlan): SparkPlan = { + child: SparkPlan, + hasInitialState: Boolean = false, + initialStateGroupingAttrs: Seq[Attribute], + initialStateDataAttrs: Seq[Attribute], + initialStateDeserializer: Expression, + initialState: SparkPlan): SparkPlan = { val shufflePartitions = child.session.sessionState.conf.numShufflePartitions val statefulOperatorStateInfo = StatefulOperatorStateInfo( checkpointLocation = "", // empty checkpointLocation will be populated in doExecute @@ -373,6 +507,12 @@ object TransformWithStateExec { None, None, child, - isStreaming = false) + isStreaming = false, + hasInitialState, + initialStateGroupingAttrs, + initialStateDataAttrs, + initialStateDeserializer, + initialState) } } +// scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index a8d4c06bc83c..2f72cbb0b0fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -112,6 +112,11 @@ object StateStoreErrors { handleState: String): StatefulProcessorCannotPerformOperationWithInvalidHandleState = { new StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType, handleState) } + + def cannotReInitializeStateOnKey(groupingKey: String): + StatefulProcessorCannotReInitializeState = { + new StatefulProcessorCannotReInitializeState(groupingKey) + } } class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) @@ -157,6 +162,11 @@ class StatefulProcessorCannotPerformOperationWithInvalidHandleState( messageParameters = Map("operationType" -> operationType, "handleState" -> handleState) ) +class StatefulProcessorCannotReInitializeState(groupingKey: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY", + messageParameters = Map("groupingKey" -> groupingKey)) + class StateStoreUnsupportedOperationOnMissingColumnFamily( operationType: String, colFamilyName: String) extends SparkUnsupportedOperationException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index d7c5ce3815b0..db8cb8b810af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.Encoders import 
org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf case class InputMapRow(key: String, action: String, value: (String, String)) @@ -82,7 +82,8 @@ class TestMapStateProcessor * Class that adds integration tests for MapState types used in arbitrary stateful * operators such as transformWithState. */ -class TransformWithMapStateSuite extends StreamTest { +class TransformWithMapStateSuite extends StreamTest + with AlsoTestWithChangelogCheckpointingEnabled { import testImplicits._ private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala new file mode 100644 index 000000000000..9f2e2c2d9f02 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.{Encoders, KeyValueGroupedDataset} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} +import org.apache.spark.sql.internal.SQLConf + +case class InitInputRow(key: String, action: String, value: Double) +case class InputRowForInitialState( + key: String, value: Double, entries: List[Double], mapping: Map[Double, Int]) + +abstract class StatefulProcessorWithInitialStateTestClass[V] + extends StatefulProcessorWithInitialState[ + String, InitInputRow, (String, String, Double), V] { + @transient var _valState: ValueState[Double] = _ + @transient var _listState: ListState[Double] = _ + @transient var _mapState: MapState[Double, Int] = _ + + override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + _valState = getHandle.getValueState[Double]("testValueInit", Encoders.scalaDouble) + _listState = getHandle.getListState[Double]("testListInit", Encoders.scalaDouble) + _mapState = getHandle.getMapState[Double, Int]( + "testMapInit", Encoders.scalaDouble, Encoders.scalaInt) + } + + override def close(): Unit = {} + + override def handleInputRows( + key: String, + inputRows: Iterator[InitInputRow], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, Double)] = { + var output = List[(String, String, Double)]() + for (row <- inputRows) { + if (row.action == "getOption") { + output = (key, row.action, _valState.getOption().getOrElse(-1.0)) :: output + } else if (row.action == "update") { + _valState.update(row.value) + } else if (row.action == "remove") { + _valState.clear() + } else if (row.action == "getList") { + _listState.get().foreach { 
element => + output = (key, row.action, element) :: output + } + } else if (row.action == "appendList") { + _listState.appendValue(row.value) + } else if (row.action == "clearList") { + _listState.clear() + } else if (row.action == "getCount") { + val count = + if (!_mapState.containsKey(row.value)) 0 + else _mapState.getValue(row.value) + output = (key, row.action, count.toDouble) :: output + } else if (row.action == "incCount") { + val count = + if (!_mapState.containsKey(row.value)) 0 + else _mapState.getValue(row.value) + _mapState.updateValue(row.value, count + 1) + } else if (row.action == "clearCount") { + _mapState.removeKey(row.value) + } + } + output.iterator + } +} + +class AccumulateStatefulProcessorWithInitState + extends StatefulProcessorWithInitialStateTestClass[(String, Double)] { + override def handleInitialState( + key: String, + initialState: (String, Double)): Unit = { + _valState.update(initialState._2) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InitInputRow], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, Double)] = { + var output = List[(String, String, Double)]() + for (row <- inputRows) { + if (row.action == "getOption") { + output = (key, row.action, _valState.getOption().getOrElse(0.0)) :: output + } else if (row.action == "add") { + // Update state variable as accumulative sum + val accumulateSum = _valState.getOption().getOrElse(0.0) + row.value + _valState.update(accumulateSum) + } else if (row.action == "remove") { + _valState.clear() + } + } + output.iterator + } +} + +class InitialStateInMemoryTestClass + extends StatefulProcessorWithInitialStateTestClass[InputRowForInitialState] { + override def handleInitialState( + key: String, + initialState: InputRowForInitialState): Unit = { + _valState.update(initialState.value) + _listState.appendList(initialState.entries.toArray) + val inMemoryMap = initialState.mapping + inMemoryMap.foreach { kvPair => + 
_mapState.updateValue(kvPair._1, kvPair._2) + } + } +} + +/** + * Class that adds tests for transformWithState stateful + * streaming operator with user-defined initial state + */ +class TransformWithStateInitialStateSuite extends StateStoreMetricsTest + with AlsoTestWithChangelogCheckpointingEnabled { + + import testImplicits._ + + private def createInitialDfForTest: KeyValueGroupedDataset[String, (String, Double)] = { + Seq(("init_1", 40.0), ("init_2", 100.0)).toDS() + .groupByKey(x => x._1) + .mapValues(x => x) + } + + + test("transformWithStateWithInitialState - correctness test, " + + "run with multiple state variables - in-memory type") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + + val inputData = MemoryStream[InitInputRow] + val kvDataSet = inputData.toDS() + .groupByKey(x => x.key) + val initStateDf = + Seq(InputRowForInitialState("init_1", 40.0, List(40.0), Map(40.0 -> 1)), + InputRowForInitialState("init_2", 100.0, List(100.0), Map(100.0 -> 1))) + .toDS().groupByKey(x => x.key).mapValues(x => x) + val query = kvDataSet.transformWithState(new InitialStateInMemoryTestClass(), + TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf) + + testStream(query, OutputMode.Update())( + // non-exist key test + AddData(inputData, InitInputRow("k1", "update", 37.0)), + AddData(inputData, InitInputRow("k2", "update", 40.0)), + AddData(inputData, InitInputRow("non-exist", "getOption", -1.0)), + CheckNewAnswer(("non-exist", "getOption", -1.0)), + AddData(inputData, InitInputRow("k1", "appendList", 37.0)), + AddData(inputData, InitInputRow("k2", "appendList", 40.0)), + AddData(inputData, InitInputRow("non-exist", "getList", -1.0)), + CheckNewAnswer(), + + AddData(inputData, InitInputRow("k1", "incCount", 37.0)), + AddData(inputData, InitInputRow("k2", "incCount", 40.0)), + AddData(inputData, InitInputRow("non-exist", "getCount", -1.0)), + CheckNewAnswer(("non-exist", "getCount", 0.0)), + 
+ AddData(inputData, InitInputRow("k2", "incCount", 40.0)), + AddData(inputData, InitInputRow("k2", "getCount", 40.0)), + CheckNewAnswer(("k2", "getCount", 2.0)), + + // verify that every row of the initial state was applied + AddData(inputData, InitInputRow("init_1", "getOption", -1.0)), + CheckNewAnswer(("init_1", "getOption", 40.0)), + AddData(inputData, InitInputRow("init_2", "getOption", -1.0)), + CheckNewAnswer(("init_2", "getOption", 100.0)), + + AddData(inputData, InitInputRow("init_1", "getList", -1.0)), + CheckNewAnswer(("init_1", "getList", 40.0)), + AddData(inputData, InitInputRow("init_2", "getList", -1.0)), + CheckNewAnswer(("init_2", "getList", 100.0)), + + AddData(inputData, InitInputRow("init_1", "getCount", 40.0)), + CheckNewAnswer(("init_1", "getCount", 1.0)), + AddData(inputData, InitInputRow("init_2", "getCount", 100.0)), + CheckNewAnswer(("init_2", "getCount", 1.0)), + + // updating state for a key that came from the initial state should work + AddData(inputData, InitInputRow("init_1", "update", 50.0)), + AddData(inputData, InitInputRow("init_1", "getOption", -1.0)), + CheckNewAnswer(("init_1", "getOption", 50.0)), + AddData(inputData, InitInputRow("init_1", "remove", -1.0)), + AddData(inputData, InitInputRow("init_1", "getOption", -1.0)), + CheckNewAnswer(("init_1", "getOption", -1.0)), + + AddData(inputData, InitInputRow("init_1", "appendList", 50.0)), + AddData(inputData, InitInputRow("init_1", "getList", -1.0)), + CheckNewAnswer(("init_1", "getList", 50.0), ("init_1", "getList", 40.0)), + + AddData(inputData, InitInputRow("init_1", "incCount", 40.0)), + AddData(inputData, InitInputRow("init_1", "getCount", 40.0)), + CheckNewAnswer(("init_1", "getCount", 2.0)), + + // test remove + AddData(inputData, InitInputRow("k1", "remove", -1.0)), + AddData(inputData, InitInputRow("k1", "getOption", -1.0)), + CheckNewAnswer(("k1", "getOption", -1.0)), + + AddData(inputData, InitInputRow("init_1", "clearCount", -1.0)), + AddData(inputData, InitInputRow("init_1", "getCount", -1.0)), +
CheckNewAnswer(("init_1", "getCount", 0.0)), + + AddData(inputData, InitInputRow("init_1", "clearList", -1.0)), + AddData(inputData, InitInputRow("init_1", "getList", -1.0)), + CheckNewAnswer() + ) + } + } + + test("transformWithStateWithInitialState -" + + " correctness test, processInitialState should only run once") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + val initStateDf = createInitialDfForTest + val inputData = MemoryStream[InitInputRow] + val query = inputData.toDS() + .groupByKey(x => x.key) + .transformWithState(new AccumulateStatefulProcessorWithInitState(), + TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf + ) + testStream(query, OutputMode.Update())( + AddData(inputData, InitInputRow("init_1", "add", 50.0)), + AddData(inputData, InitInputRow("init_2", "add", 60.0)), + AddData(inputData, InitInputRow("init_1", "add", 50.0)), + // If handleInitialState ran more than once for a key, + // the following checks will fail + AddData(inputData, + InitInputRow("init_1", "getOption", -1.0), InitInputRow("init_2", "getOption", -1.0)), + CheckNewAnswer(("init_2", "getOption", 160.0), ("init_1", "getOption", 140.0)) + ) + } + } + + test("transformWithStateWithInitialState - batch should succeed") { + val inputData = Seq(InitInputRow("k1", "add", 37.0), InitInputRow("k1", "getOption", -1.0)) + val result = inputData.toDS() + .groupByKey(x => x.key) + .transformWithState(new AccumulateStatefulProcessorWithInitState(), + TimeoutMode.NoTimeouts(), + OutputMode.Append(), + createInitialDfForTest) + + val df = result.toDF() + checkAnswer(df, Seq(("k1", "getOption", 37.0)).toDF()) + } + + test("transformWithStateWithInitialState - " + + "cannot re-initialize state during initial state handling") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + val initDf = Seq(("init_1", 40.0), ("init_2", 100.0), ("init_1", 50.0)).toDS() + .groupByKey(x =>
x._1).mapValues(x => x) + val inputData = MemoryStream[InitInputRow] + val query = inputData.toDS() + .groupByKey(x => x.key) + .transformWithState(new AccumulateStatefulProcessorWithInitState(), + TimeoutMode.NoTimeouts(), + OutputMode.Append(), + initDf) + + testStream(query, OutputMode.Update())( + AddData(inputData, InitInputRow("k1", "add", 50.0)), + Execute { q => + val e = intercept[Exception] { + q.processAllAvailable() + } + checkError( + exception = e.getCause.asInstanceOf[SparkUnsupportedOperationException], + errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY", + sqlState = Some("42802"), + parameters = Map("groupingKey" -> "init_1") + ) + } + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 24b0d59c45c5..24e68e3db9d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -769,4 +769,24 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { } ) } + + test("transformWithStateWithInitialState - streaming with hdfsStateStoreProvider should fail") { + val inputData = MemoryStream[InitInputRow] + val initDf = Seq(("init_1", 40.0), ("init_2", 100.0)).toDS() + .groupByKey(x => x._1) + .mapValues(x => x) + val result = inputData.toDS() + .groupByKey(x => x.key) + .transformWithState(new AccumulateStatefulProcessorWithInitState(), + TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf + ) + testStream(result, OutputMode.Update())( + AddData(inputData, InitInputRow("a", "add", -1.0)), + ExpectFailure[StateStoreMultipleColumnFamiliesNotSupportedException] { + (t: Throwable) => { + assert(t.getMessage.contains("not supported")) + } + } + ) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: 
commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org