This is an automated email from the ASF dual-hosted git repository.
ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2871905e9ba8 [SPARK-54924][SS] State Rewriter to read state, transform it and write new state
2871905e9ba8 is described below
commit 2871905e9ba824f0c1c3ae2397334e838fa92faf
Author: micheal-o <[email protected]>
AuthorDate: Wed Jan 7 12:52:24 2026 -0800
[SPARK-54924][SS] State Rewriter to read state, transform it and write new state
### What changes were proposed in this pull request?
Introduce the State Rewriter, which rewrites the state stores of a stateful streaming query: it reads the existing state, optionally transforms it, and then writes the new state.
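As a rough sketch (not part of this patch), the rewriter is driven by constructing it against a checkpoint and calling `run()`. The session, paths, batch IDs, and transform below are hypothetical placeholders; the write batch is assumed to already have an offset log entry, since the rewriter derives the store configuration from it.
```scala
// Hypothetical usage sketch: rewrite the state of batch 5 into a new batch 6
// at the same checkpoint location. `spark`, the path, and the batch IDs are
// placeholders, not values from this patch.
val rewriter = new StateRewriter(
  sparkSession = spark,
  readBatchId = 5L,
  writeBatchId = 6L,
  resolvedCheckpointLocation = "/tmp/checkpoint",  // placeholder path
  hadoopConf = spark.sessionState.newHadoopConf(),
  // Optional: read from a different checkpoint, and/or transform each
  // operator's state DataFrame before it is written back. The transform
  // must preserve the DataFrame schema.
  readResolvedCheckpointLocation = None,
  transformFunc = Some(df => df.filter("column_family_name IS NOT NULL"))
)
rewriter.run()
```
The `transformFunc` above is only an illustrative, schema-preserving filter; offline repartitioning (or any other rewrite) would plug its own transformation in at that point.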
### Why are the changes needed?
For offline state repartitioning and other future use cases.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Updated existing tests
### Was this patch authored or co-authored using generative AI tooling?
Yes, Claude-4.5-opus
Closes #53703 from micheal-o/state_rewriter.
Authored-by: micheal-o <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 26 ++
.../streaming/state/OperatorStateMetadata.scala | 9 +
.../streaming/state/StatePartitionWriter.scala | 3 +-
.../execution/streaming/state/StateRewriter.scala | 404 +++++++++++++++++++++
.../v2/state/StateDataSourceTestBase.scala | 2 +-
...tatePartitionAllColumnFamiliesWriterSuite.scala | 324 +++++------------
6 files changed, 542 insertions(+), 226 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 18ef49f01bef..1f9c7321b50f 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5587,6 +5587,32 @@
},
"sqlState" : "42616"
},
+ "STATE_REWRITER_INVALID_CHECKPOINT" : {
+ "message" : [
+ "The state rewrite checkpoint location '<checkpointLocation>' is in an
invalid state."
+ ],
+ "subClass" : {
+ "MISSING_KEY_ENCODER_SPEC" : {
+ "message" : [
+ "Key state encoder spec is expected for column family
'<colFamilyName>' but was not found.",
+ "This is likely a bug, please report it."
+ ]
+ },
+ "MISSING_OPERATOR_METADATA" : {
+ "message" : [
+ "No stateful operator metadata was found for batch <batchId>.",
+ "Ensure that the checkpoint is for a stateful streaming query and
the query ran on a Spark version that supports operator metadata (Spark 4.0+)."
+ ]
+ },
+ "UNSUPPORTED_STATE_STORE_METADATA_VERSION" : {
+ "message" : [
+ "Unsupported state store metadata version encountered.",
+ "Only StateStoreMetadataV1 and StateStoreMetadataV2 are supported."
+ ]
+ }
+ },
+ "sqlState" : "55019"
+ },
"STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : {
"message" : [
"Failed to create column family with unsupported starting character and
name=<colFamilyName>."
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
index c34545216fda..6b2295da03b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
import java.io.{BufferedReader, InputStreamReader}
import java.nio.charset.StandardCharsets
+import scala.collection.immutable.ArraySeq
import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
@@ -80,12 +81,17 @@ trait OperatorStateMetadata {
def version: Int
def operatorInfo: OperatorInfo
+
+ def stateStoresMetadata: Seq[StateStoreMetadata]
}
case class OperatorStateMetadataV1(
operatorInfo: OperatorInfoV1,
stateStoreInfo: Array[StateStoreMetadataV1]) extends OperatorStateMetadata {
override def version: Int = 1
+
+ override def stateStoresMetadata: Seq[StateStoreMetadata] =
+ ArraySeq.unsafeWrapArray(stateStoreInfo)
}
case class OperatorStateMetadataV2(
@@ -93,6 +99,9 @@ case class OperatorStateMetadataV2(
stateStoreInfo: Array[StateStoreMetadataV2],
operatorPropertiesJson: String) extends OperatorStateMetadata {
override def version: Int = 2
+
+ override def stateStoresMetadata: Seq[StateStoreMetadata] =
+ ArraySeq.unsafeWrapArray(stateStoreInfo)
}
object OperatorStateMetadataUtils extends Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
index aac13d3f69f9..3df97d3adc0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
@@ -51,7 +51,7 @@ class StatePartitionAllColumnFamiliesWriter(
hadoopConf: Configuration,
partitionId: Int,
targetCpLocation: String,
- operatorId: Int,
+ operatorId: Long,
storeName: String,
currentBatchId: Long,
colFamilyToWriterInfoMap: Map[String, StatePartitionWriterColumnFamilyInfo],
@@ -153,6 +153,7 @@ class StatePartitionAllColumnFamiliesWriter(
if (!stateStore.hasCommitted) {
stateStore.abort()
}
+ provider.close()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
new file mode 100644
index 000000000000..da28a3c907f7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.UUID
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.{SparkIllegalStateException, TaskContext}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
+import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqMetadata
+import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.runtime.{StreamingCheckpointConstants, StreamingQueryCheckpointMetadata}
+import org.apache.spark.sql.execution.streaming.state.{StatePartitionAllColumnFamiliesWriter, StateSchemaCompatibilityChecker}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+
+/**
+ * State Rewriter is used to rewrite the state stores for a stateful streaming query.
+ * It reads state from a checkpoint location, optionally applies a transformation to the state,
+ * and then writes the state back to a (possibly different) checkpoint location for a new batch ID.
+ *
+ * An example use case is offline state repartitioning.
+ * It can also be used to support other use cases.
+ *
+ * @param sparkSession The active Spark session.
+ * @param readBatchId The batch ID for reading state.
+ * @param writeBatchId The batch ID to which the (transformed) state will be written.
+ * @param resolvedCheckpointLocation The resolved checkpoint path where state will be written.
+ * @param hadoopConf Hadoop configuration for file system operations.
+ * @param readResolvedCheckpointLocation Optional separate checkpoint location to read state from.
+ *                                       If None, reads from resolvedCheckpointLocation.
+ * @param transformFunc Optional transformation function applied to each operator's state
+ *                      DataFrame. If None, state is written as-is.
+ * @param writeCheckpointMetadata Optional checkpoint metadata for the resolvedCheckpointLocation.
+ *                                If None, will create a new one for resolvedCheckpointLocation.
+ *                                Helps us reuse already cached checkpoint log entries,
+ *                                instead of starting from scratch.
+ */
+class StateRewriter(
+ sparkSession: SparkSession,
+ readBatchId: Long,
+ writeBatchId: Long,
+ resolvedCheckpointLocation: String,
+ hadoopConf: Configuration,
+ readResolvedCheckpointLocation: Option[String] = None,
+ transformFunc: Option[DataFrame => DataFrame] = None,
+ writeCheckpointMetadata: Option[StreamingQueryCheckpointMetadata] = None
+) extends Logging {
+ require(readResolvedCheckpointLocation.isDefined || readBatchId < writeBatchId,
+ s"Read batch id $readBatchId must be less than write batch id $writeBatchId " +
+ "when reading and writing to the same checkpoint location")
+
+ // If a different location was specified for reading state, use it.
+ // Else, use same location for reading and writing state.
+ private val checkpointLocationForRead =
+ readResolvedCheckpointLocation.getOrElse(resolvedCheckpointLocation)
+ private val stateRootLocation = new Path(
+ resolvedCheckpointLocation, StreamingCheckpointConstants.DIR_NAME_STATE).toString
+
+ def run(): Unit = {
+ logInfo(log"Starting state rewrite for " +
+ log"checkpointLocation=${MDC(CHECKPOINT_LOCATION,
resolvedCheckpointLocation)}, " +
+ log"readCheckpointLocation=" +
+ log"${MDC(CHECKPOINT_LOCATION,
readResolvedCheckpointLocation.getOrElse(""))}, " +
+ log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
+ log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
+
+ val (_, timeTakenMs) = Utils.timeTakenMs {
+ runInternal()
+ }
+
+ logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms
for " +
+ log"checkpointLocation=${MDC(CHECKPOINT_LOCATION,
resolvedCheckpointLocation)}")
+ }
+
+ private def runInternal(): Unit = {
+ try {
+ val stateMetadataReader = new StateMetadataPartitionReader(
+ resolvedCheckpointLocation,
+ new SerializableConfiguration(hadoopConf),
+ readBatchId)
+
+ val allOperatorsMetadata = stateMetadataReader.allOperatorStateMetadata
+ if (allOperatorsMetadata.isEmpty) {
+ // It's possible that the query is stateless
+ // or ran on an older Spark version without operator metadata
+ throw StateRewriterErrors.missingOperatorMetadataError(
+ resolvedCheckpointLocation, readBatchId)
+ }
+
+ // Use the same conf in the offset log to create the store conf,
+ // to make sure the state is written with the right conf.
+ val (storeConf, sqlConf) = createConfsFromOffsetLog()
+ // SQLConf doesn't serialize properly (reader becomes null), so extract as Map
+ val sqlConfEntries: Map[String, String] = sqlConf.getAllConfs
+
+ // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
+ val hadoopConfBroadcast =
+ SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf)
+
+ // Do rewrite for each operator
+ // We can potentially parallelize this, but for now, do sequentially
+ allOperatorsMetadata.foreach { opMetadata =>
+ val stateStoresMetadata = opMetadata.stateStoresMetadata
+ assert(!stateStoresMetadata.isEmpty,
+ s"Operator ${opMetadata.operatorInfo.operatorName} has no state
stores")
+
+ val storeToSchemaFilesMap = getStoreToSchemaFilesMap(opMetadata)
+ val stateVarsIfTws = getStateVariablesIfTWS(opMetadata)
+
+ // Rewrite each state store of the operator
+ stateStoresMetadata.foreach { stateStoreMetadata =>
+ rewriteStore(
+ opMetadata,
+ stateStoreMetadata,
+ storeConf,
+ hadoopConfBroadcast,
+ storeToSchemaFilesMap(stateStoreMetadata.storeName),
+ stateVarsIfTws,
+ sqlConfEntries
+ )
+ }
+ }
+ } catch {
+ case e: Throwable =>
+ logError(log"State rewrite failed for " +
+ log"checkpointLocation=${MDC(CHECKPOINT_LOCATION,
resolvedCheckpointLocation)}, " +
+ log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
+ log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}", e)
+ throw e
+ }
+ }
+
+ private def rewriteStore(
+ opMetadata: OperatorStateMetadata,
+ stateStoreMetadata: StateStoreMetadata,
+ storeConf: StateStoreConf,
+ hadoopConfBroadcast: Broadcast[SerializableConfiguration],
+ storeSchemaFiles: List[Path],
+ stateVarsIfTws: Map[String, TransformWithStateVariableInfo],
+ sqlConfEntries: Map[String, String]
+ ): Unit = {
+ // Read state
+ val stateDf = sparkSession.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, checkpointLocationForRead)
+ .option(StateSourceOptions.BATCH_ID, readBatchId)
+ .option(StateSourceOptions.OPERATOR_ID, opMetadata.operatorInfo.operatorId)
+ .option(StateSourceOptions.STORE_NAME, stateStoreMetadata.storeName)
+ .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true")
+ .load()
+
+ // Run the caller state transformation func if provided
+ // Otherwise, use the state as is
+ val updatedStateDf = transformFunc.map(func => func(stateDf)).getOrElse(stateDf)
+ require(updatedStateDf.schema == stateDf.schema,
+ s"State transformation function must return a DataFrame with the same
schema " +
+ s"as the original state DataFrame. Original schema: ${stateDf.schema},
" +
+ s"Updated schema: ${updatedStateDf.schema}")
+
+ val schemaProvider = createStoreSchemaProviderIfTWS(
+ opMetadata.operatorInfo.operatorName,
+ storeSchemaFiles
+ )
+ val writerColFamilyInfoMap = getWriterColFamilyInfoMap(
+ opMetadata.operatorInfo.operatorId,
+ stateStoreMetadata,
+ storeSchemaFiles,
+ stateVarsIfTws
+ )
+
+ logInfo(log"Writing new state for " +
+ log"operator=${MDC(OP_TYPE, opMetadata.operatorInfo.operatorName)}, " +
+ log"stateStore=${MDC(STATE_NAME, stateStoreMetadata.storeName)}, " +
+ log"numColumnFamilies=${MDC(COUNT, writerColFamilyInfoMap.size)}, " +
+ log"numSchemaFiles=${MDC(NUM_FILES, storeSchemaFiles.size)}, " +
+ log"for new batch=${MDC(BATCH_ID, writeBatchId)}, " +
+ log"for checkpoint=${MDC(CHECKPOINT_LOCATION,
resolvedCheckpointLocation)}")
+
+ // Write state for each partition on the executor.
+ // Setting these as local vals,
+ // to avoid serializing the entire Rewriter object per partition.
+ val targetCheckpointLocation = resolvedCheckpointLocation
+ val currentBatchId = writeBatchId
+ updatedStateDf.queryExecution.toRdd.foreachPartition { partitionIter =>
+ // Recreate SQLConf on executor from serialized entries
+ val executorSqlConf = new SQLConf()
+ sqlConfEntries.foreach { case (k, v) => executorSqlConf.setConfString(k, v) }
+
+ val partitionWriter = new StatePartitionAllColumnFamiliesWriter(
+ storeConf,
+ hadoopConfBroadcast.value.value,
+ TaskContext.get().partitionId(),
+ targetCheckpointLocation,
+ opMetadata.operatorInfo.operatorId,
+ stateStoreMetadata.storeName,
+ currentBatchId,
+ writerColFamilyInfoMap,
+ opMetadata.operatorInfo.operatorName,
+ schemaProvider,
+ executorSqlConf
+ )
+
+ partitionWriter.write(partitionIter)
+ }
+ }
+
+ /** Create the store and sql confs from the conf written in the offset log */
+ private def createConfsFromOffsetLog(): (StateStoreConf, SQLConf) = {
+ val offsetLog = writeCheckpointMetadata.getOrElse(
+ new StreamingQueryCheckpointMetadata(sparkSession, resolvedCheckpointLocation)).offsetLog
+
+ // We want to use the same confs written in the offset log for the new batch
+ val offsetSeq = offsetLog.get(writeBatchId)
+ require(offsetSeq.isDefined, s"Offset seq must be present for the new batch $writeBatchId")
+ val metadata = offsetSeq.get.metadataOpt
+ require(metadata.isDefined, s"Metadata must be present for the new batch $writeBatchId")
+
+ val clonedSqlConf = sparkSession.sessionState.conf.clone()
+ OffsetSeqMetadata.setSessionConf(metadata.get, clonedSqlConf)
+ (StateStoreConf(clonedSqlConf), clonedSqlConf)
+ }
+
+ /** Get the map of state store name to schema files, for an operator */
+ private def getStoreToSchemaFilesMap(
+ opMetadata: OperatorStateMetadata): Map[String, List[Path]] = {
+ opMetadata.stateStoresMetadata.map { storeMetadata =>
+ val schemaFiles = storeMetadata match {
+ // No schema files for v1. It has a fixed/known schema file path
+ case _: StateStoreMetadataV1 => List.empty[Path]
+ case v2: StateStoreMetadataV2 => v2.stateSchemaFilePaths.map(new Path(_))
+ case _ =>
+ throw StateRewriterErrors.unsupportedStateStoreMetadataVersionError(
+ resolvedCheckpointLocation)
+ }
+ storeMetadata.storeName -> schemaFiles
+ }.toMap
+ }
+
+ private def getWriterColFamilyInfoMap(
+ operatorId: Long,
+ storeMetadata: StateStoreMetadata,
+ schemaFiles: List[Path],
+ twsStateVariables: Map[String, TransformWithStateVariableInfo] = Map.empty
+ ): Map[String, StatePartitionWriterColumnFamilyInfo] = {
+ getLatestColFamilyToSchemaMap(operatorId, storeMetadata, schemaFiles)
+ .map { case (colFamilyName, schema) =>
+ colFamilyName -> StatePartitionWriterColumnFamilyInfo(schema,
+ useMultipleValuesPerKey = twsStateVariables.get(colFamilyName)
+ .map(_.stateVariableType == StateVariableType.ListState).getOrElse(false))
+ }
+ }
+
+ private def getLatestColFamilyToSchemaMap(
+ operatorId: Long,
+ storeMetadata: StateStoreMetadata,
+ schemaFiles: List[Path]): Map[String, StateStoreColFamilySchema] = {
+ val storeId = new StateStoreId(
+ stateRootLocation,
+ operatorId,
+ StateStore.PARTITION_ID_TO_CHECK_SCHEMA,
+ storeMetadata.storeName)
+ // using a placeholder runId since we are not running a streaming query
+ val providerId = new StateStoreProviderId(storeId, queryRunId = UUID.randomUUID())
+ val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf,
+ oldSchemaFilePaths = schemaFiles)
+ // Read the latest state schema from the provided path for v2 or from the dedicated path
+ // for v1
+ manager
+ .readSchemaFile()
+ .map { schema =>
+ schema.colFamilyName -> createKeyEncoderSpecIfAbsent(schema, storeMetadata) }.toMap
+ }
+
+ private def createKeyEncoderSpecIfAbsent(
+ colFamilySchema: StateStoreColFamilySchema,
+ storeMetadata: StateStoreMetadata): StateStoreColFamilySchema = {
+ colFamilySchema.keyStateEncoderSpec match {
+ case Some(encoderSpec) => colFamilySchema
+ case None if storeMetadata.isInstanceOf[StateStoreMetadataV1] =>
+ // Create the spec if missing for v1 metadata
+ if (storeMetadata.numColsPrefixKey > 0) {
+ colFamilySchema.copy(keyStateEncoderSpec =
+ Some(PrefixKeyScanStateEncoderSpec(colFamilySchema.keySchema, storeMetadata.numColsPrefixKey)))
+ } else {
+ colFamilySchema.copy(keyStateEncoderSpec =
+ Some(NoPrefixKeyStateEncoderSpec(colFamilySchema.keySchema)))
+ }
+ case _ =>
+ // Key encoder spec is expected in v2 metadata
+ throw StateRewriterErrors.missingKeyEncoderSpecError(
+ resolvedCheckpointLocation, colFamilySchema.colFamilyName)
+ }
+ }
+
+ private def getStateVariablesIfTWS(
+ opMetadata: OperatorStateMetadata): Map[String, TransformWithStateVariableInfo] = {
+ if (StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES
+ .contains(opMetadata.operatorInfo.operatorName)) {
+ val operatorProperties = TransformWithStateOperatorProperties.fromJson(
+ opMetadata.asInstanceOf[OperatorStateMetadataV2].operatorPropertiesJson)
+ operatorProperties.stateVariables.map(s => s.stateName -> s).toMap
+ } else {
+ Map.empty
+ }
+ }
+
+ // Needed only for schema evolution for TWS
+ private def createStoreSchemaProviderIfTWS(
+ opName: String,
+ schemaFiles: List[Path]): Option[StateSchemaProvider] = {
+ if (StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(opName)) {
+ val schemaMetadata = StateSchemaMetadata.createStateSchemaMetadata(
+ stateRootLocation, hadoopConf, schemaFiles.map(_.toString))
+ Some(new InMemoryStateSchemaProvider(schemaMetadata))
+ } else {
+ None
+ }
+ }
+}
+
+/**
+ * Errors thrown by StateRewriter.
+ */
+private[state] object StateRewriterErrors {
+ def missingKeyEncoderSpecError(
+ checkpointLocation: String,
+ colFamilyName: String): StateRewriterInvalidCheckpointError = {
+ new StateRewriterMissingKeyEncoderSpecError(checkpointLocation, colFamilyName)
+ }
+
+ def missingOperatorMetadataError(
+ checkpointLocation: String,
+ batchId: Long): StateRewriterInvalidCheckpointError = {
+ new StateRewriterMissingOperatorMetadataError(checkpointLocation, batchId)
+ }
+
+ def unsupportedStateStoreMetadataVersionError(
+ checkpointLocation: String): StateRewriterInvalidCheckpointError = {
+ new StateRewriterUnsupportedStoreMetadataVersionError(checkpointLocation)
+ }
+}
+
+/**
+ * Base class for exceptions thrown when the checkpoint location is in an invalid state
+ * for state rewriting.
+ */
+private[state] abstract class StateRewriterInvalidCheckpointError(
+ checkpointLocation: String,
+ subClass: String,
+ messageParameters: Map[String, String],
+ cause: Throwable = null)
+ extends SparkIllegalStateException(
+ errorClass = s"STATE_REWRITER_INVALID_CHECKPOINT.$subClass",
+ messageParameters = Map("checkpointLocation" -> checkpointLocation) ++
messageParameters,
+ cause = cause)
+
+private[state] class StateRewriterMissingKeyEncoderSpecError(
+ checkpointLocation: String,
+ colFamilyName: String)
+ extends StateRewriterInvalidCheckpointError(
+ checkpointLocation,
+ subClass = "MISSING_KEY_ENCODER_SPEC",
+ messageParameters = Map("colFamilyName" -> colFamilyName))
+
+private[state] class StateRewriterMissingOperatorMetadataError(
+ checkpointLocation: String,
+ batchId: Long)
+ extends StateRewriterInvalidCheckpointError(
+ checkpointLocation,
+ subClass = "MISSING_OPERATOR_METADATA",
+ messageParameters = Map("batchId" -> batchId.toString))
+
+private[state] class StateRewriterUnsupportedStoreMetadataVersionError(
+ checkpointLocation: String)
+ extends StateRewriterInvalidCheckpointError(
+ checkpointLocation,
+ subClass = "UNSUPPORTED_STATE_STORE_METADATA_VERSION",
+ messageParameters = Map.empty)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
index 6a1f66262d5f..d8a3bbb65af2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
@@ -666,7 +666,7 @@ object SessionWindowTestUtils {
*/
object StreamStreamJoinTestUtils {
// All state store names from SymmetricHashJoinStateManager
- private val allStoreNames: Seq[String] =
+ val allStoreNames: Seq[String] =
SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
// Column family names for keyToNumValues stores (derived from allStateStoreNames)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
index 9501e4e9e36b..e495db499bfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
@@ -20,10 +20,8 @@ import java.io.File
import java.sql.Timestamp
import java.time.Duration
-import org.apache.spark.TaskContext
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources.v2.state.{CompositeKeyAggregationTestUtils, DropDuplicatesTestUtils, FlatMapGroupsWithStateTestUtils, SessionWindowTestUtils, SimpleAggregationTestUtils, StateDataSourceTestBase, StateSourceOptions, StreamStreamJoinTestUtils}
+import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceTestBase, StateSourceOptions, StreamStreamJoinTestUtils}
import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryCheckpointMetadata}
@@ -34,7 +32,6 @@ import org.apache.spark.sql.streaming.{InputEvent, ListStateTTLProcessor, MapInp
import org.apache.spark.sql.streaming.util.{StreamManualClock, TTLProcessorUtils}
import org.apache.spark.sql.streaming.util.{EventTimeTimerProcessor, MultiStateVarProcessor, MultiStateVarProcessorTestUtils, TimerTestUtils}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.SerializableConfiguration
/**
* Test suite for StatePartitionAllColumnFamiliesWriter.
@@ -51,68 +48,33 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "2")
}
- /**
- * Helper method to create a StateSchemaProvider from column family schema map.
- */
- private def createStateSchemaProvider(
- columnFamilyToSchemaMap: Map[String, StatePartitionWriterColumnFamilyInfo]
- ): StateSchemaProvider = {
- val testSchemaProvider = new TestStateSchemaProvider()
- columnFamilyToSchemaMap.foreach { case (cfName, cfInfo) =>
- testSchemaProvider.captureSchema(
- colFamilyName = cfName,
- keySchema = cfInfo.schema.keySchema,
- valueSchema = cfInfo.schema.valueSchema,
- keySchemaId = cfInfo.schema.keySchemaId,
- valueSchemaId = cfInfo.schema.valueSchemaId
- )
- }
- testSchemaProvider
- }
-
/**
* Common helper method to perform round-trip test: read state bytes from source,
* write to target, and verify target matches source.
*
* @param sourceDir Source checkpoint directory
* @param targetDir Target checkpoint directory
- * @param columnFamilyToSchemaMap Map of column family names to their schemas
- * @param storeName Optional store name (for stream-stream join which has multiple stores)
- * @param columnFamilyToSelectExprs Map of column family names to custom selectExprs
- * @param columnFamilyToStateSourceOptions Map of column family names to state source options
+ * @param storeToColumnFamilies Optional store name to its column families
+ * @param storeToColumnFamilyToSelectExprs Map store name to per column family custom selectExprs
+ * @param storeToColumnFamilyToStateSourceOptions Map store name to per column family
+ *                                                state source options
*/
private def performRoundTripTest(
sourceDir: String,
targetDir: String,
- columnFamilyToSchemaMap: Map[String, StatePartitionWriterColumnFamilyInfo],
- storeName: Option[String] = None,
- columnFamilyToSelectExprs: Map[String, Seq[String]] = Map.empty,
- columnFamilyToStateSourceOptions: Map[String, Map[String, String]] = Map.empty,
+ storeToColumnFamilies: Map[String, List[String]] =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> List(StateStore.DEFAULT_COL_FAMILY_NAME)),
+ storeToColumnFamilyToSelectExprs: Map[String, Map[String, Seq[String]]] = Map.empty,
+ storeToColumnFamilyToStateSourceOptions: Map[String, Map[String, Map[String, String]]] =
+ Map.empty,
operatorName: String): Unit = {
-
- val columnFamiliesToValidate: Seq[String] = if
(columnFamilyToSchemaMap.size > 1) {
- columnFamilyToSchemaMap.keys.toSeq
- } else {
- Seq(StateStore.DEFAULT_COL_FAMILY_NAME)
- }
-
- // Step 1: Read from source using AllColumnFamiliesReader (raw bytes)
- val sourceBytesReader = spark.read
- .format("statestore")
- .option(StateSourceOptions.PATH, sourceDir)
- .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
"true")
- val sourceBytesData = (storeName match {
- case Some(name) =>
sourceBytesReader.option(StateSourceOptions.STORE_NAME, name)
- case None => sourceBytesReader
- }).load()
-
- // Verify schema of raw bytes
- val schema = sourceBytesData.schema
- assert(schema.fieldNames === Array(
- "partition_key", "key_bytes", "value_bytes", "column_family_name"))
-
- // Step 2: Write raw bytes to target checkpoint location
val hadoopConf = spark.sessionState.newHadoopConf()
+ val sourceCpLocation = StreamingUtils.resolvedCheckpointLocation(
+ hadoopConf, sourceDir)
+ val sourceCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+ spark, sourceCpLocation)
+ val readBatchId = sourceCheckpointMetadata.commitLog.getLatestBatchId().get
+
val targetCpLocation = StreamingUtils.resolvedCheckpointLocation(
hadoopConf, targetDir)
val targetCheckpointMetadata = new StreamingQueryCheckpointMetadata(
@@ -120,67 +82,56 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
// increase offsetCheckpoint
val lastBatch = targetCheckpointMetadata.commitLog.getLatestBatchId().get
val targetOffsetSeq = targetCheckpointMetadata.offsetLog.get(lastBatch).get
- val currentBatchId = lastBatch + 1
- targetCheckpointMetadata.offsetLog.add(currentBatchId, targetOffsetSeq)
-
- val storeConf: StateStoreConf = StateStoreConf(spark.sessionState.conf)
- val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
-
- // Create StateSchemaProvider if needed (for Avro encoding)
- val stateSchemaProvider = if (storeConf.stateStoreEncodingFormat ==
"avro") {
- Some(createStateSchemaProvider(columnFamilyToSchemaMap))
- } else {
- None
- }
- val baseConfs: Map[String, String] = spark.sessionState.conf.getAllConfs
- val putPartitionFunc: Iterator[InternalRow] => Unit = partition => {
- val newConf = new SQLConf
- baseConfs.foreach { case (k, v) =>
- newConf.setConfString(k, v)
- }
- val allCFWriter = new StatePartitionAllColumnFamiliesWriter(
- storeConf,
- serializableHadoopConf.value,
- TaskContext.getPartitionId(),
- targetCpLocation,
- 0,
- storeName.getOrElse(StateStoreId.DEFAULT_STORE_NAME),
- currentBatchId,
- columnFamilyToSchemaMap,
- operatorName,
- stateSchemaProvider,
- newConf
- )
- allCFWriter.write(partition)
- }
- sourceBytesData.queryExecution.toRdd.foreachPartition(putPartitionFunc)
+ val writeBatchId = lastBatch + 1
+ targetCheckpointMetadata.offsetLog.add(writeBatchId, targetOffsetSeq)
+
+ val rewriter = new StateRewriter(
+ spark,
+ readBatchId,
+ writeBatchId,
+ targetCpLocation,
+ hadoopConf,
+ readResolvedCheckpointLocation = Some(sourceCpLocation),
+ transformFunc = None,
+ writeCheckpointMetadata = Some(targetCheckpointMetadata)
+ )
+ rewriter.run()
// Commit to commitLog
val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
- targetCheckpointMetadata.commitLog.add(currentBatchId, latestCommit)
- val versionToCheck = currentBatchId + 1
- val storeNamePath = s"state/0/0${storeName.fold("")("/" + _)}"
- assert(!checkpointFileExists(new File(targetDir, storeNamePath),
versionToCheck, ".changelog"))
- assert(checkpointFileExists(new File(targetDir, storeNamePath),
versionToCheck, ".zip"))
+ targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+ val versionToCheck = writeBatchId + 1
+
+ storeToColumnFamilies.foreach { case (storeName, columnFamilies) =>
+ val storeNamePath = if (storeName == StateStoreId.DEFAULT_STORE_NAME) {
+ "state/0/0"
+ } else {
+ s"state/0/0/$storeName"
+ }
+ assert(!checkpointFileExists(new File(targetDir, storeNamePath),
+ versionToCheck, ".changelog"))
+ assert(checkpointFileExists(new File(targetDir, storeNamePath), versionToCheck, ".zip"))
- // Step 3: Validate by reading from both source and target using normal reader"
- // Default selectExprs for most column families
- val defaultSelectExprs = Seq("key", "value", "partition_id")
+ // Validate by reading from both source and target using normal reader
+ // Default selectExprs for most column families
+ val defaultSelectExprs = Seq("key", "value", "partition_id")
- columnFamiliesToValidate
+ columnFamilies
// filtering out "default" for TWS operator because it doesn't contain
any data
.filter(cfName => !(cfName == StateStore.DEFAULT_COL_FAMILY_NAME &&
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName)
))
.foreach { cfName =>
- val selectExprs = columnFamilyToSelectExprs.getOrElse(cfName, defaultSelectExprs)
- val readerOptions = columnFamilyToStateSourceOptions.getOrElse(cfName, Map.empty)
+ val selectExprs = storeToColumnFamilyToSelectExprs.getOrElse(storeName, Map.empty)
+ .getOrElse(cfName, defaultSelectExprs)
+ val readerOptions = storeToColumnFamilyToStateSourceOptions.getOrElse(storeName, Map.empty)
+ .getOrElse(cfName, Map.empty)
def readNormalData(dir: String): Array[Row] = {
var reader = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, dir)
- .option(StateSourceOptions.STORE_NAME, storeName.orNull)
+ .option(StateSourceOptions.STORE_NAME, storeName)
readerOptions.foreach { case (k, v) => reader = reader.option(k, v) }
reader.load()
.selectExpr(selectExprs: _*)
@@ -192,6 +143,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
validateDataMatches(sourceNormalData, targetNormalData)
}
+ }
}
/**
@@ -308,15 +260,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
)
)
- // Step 2: Define schemas based on state version
- val metadata =
SimpleAggregationTestUtils.getSchemasWithMetadata(stateVersion)
-
// Perform round-trip test using common helper
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- createSingleColumnFamilySchemaMap(
- metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
operatorName = StatefulOperatorsUtils.STATE_STORE_SAVE_EXEC_OP_NAME
)
}
@@ -349,15 +296,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
)
)
- // Step 2: Define schemas based on state version for composite key
- val metadata =
CompositeKeyAggregationTestUtils.getSchemasWithMetadata(stateVersion)
-
// Perform round-trip test using common helper
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- createSingleColumnFamilySchemaMap(
- metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
operatorName = StatefulOperatorsUtils.STATE_STORE_SAVE_EXEC_OP_NAME
)
}
@@ -396,36 +338,15 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
)
// Step 2: Test all 4 state stores created by stream-stream join
- // Test keyToNumValues stores (both left and right)
- StreamStreamJoinTestUtils.KEY_TO_NUM_VALUES_ALL.foreach { storeName
=>
- val metadata =
StreamStreamJoinTestUtils.getKeyToNumValuesSchemasWithMetadata()
-
- // Perform round-trip test using common helper
- performRoundTripTest(
- sourceDir.getAbsolutePath,
- targetDir.getAbsolutePath,
- createSingleColumnFamilySchemaMap(
- metadata.keySchema, metadata.valueSchema,
metadata.encoderSpec),
- storeName = Some(storeName),
- operatorName =
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
- )
- }
-
- // Test keyWithIndexToValue stores (both left and right)
- StreamStreamJoinTestUtils.KEY_WITH_INDEX_ALL.foreach { storeName =>
- val metadata =
-
StreamStreamJoinTestUtils.getKeyWithIndexToValueSchemasWithMetadata(stateVersion)
-
- // Perform round-trip test using common helper
- performRoundTripTest(
- sourceDir.getAbsolutePath,
- targetDir.getAbsolutePath,
- createSingleColumnFamilySchemaMap(
- metadata.keySchema, metadata.valueSchema,
metadata.encoderSpec),
- storeName = Some(storeName),
- operatorName =
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
- )
- }
+ val storeToColumnFamilies = StreamStreamJoinTestUtils.allStoreNames
+ .map(s => s -> List(StateStore.DEFAULT_COL_FAMILY_NAME)).toMap
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ storeToColumnFamilies,
+ operatorName =
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
+ )
}
}
}
@@ -458,15 +379,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
CheckLastBatch(("a", 1, 0, false))
)
- // Step 2: Define schemas for flatMapGroupsWithState
- val metadata =
FlatMapGroupsWithStateTestUtils.getSchemasWithMetadata(stateVersion)
-
// Perform round-trip test using common helper
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- createSingleColumnFamilySchemaMap(
- metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
operatorName =
StatefulOperatorsUtils.FLAT_MAP_GROUPS_WITH_STATE_EXEC_OP_NAME
)
}
@@ -474,7 +390,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
}
}
-
/**
* Helper method to build timer column family schemas and options for
* RunningCountStatefulProcessorWithProcTimeTimer and EventTimeTimerProcessor
@@ -564,16 +479,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
CheckAnswer(("a", 1))
)
- // Step 2: Define schemas for dropDuplicatesWithinWatermark
- val metadata =
-
DropDuplicatesTestUtils.getDropDuplicatesWithinWatermarkSchemasWithMetadata()
-
// Perform round-trip test using common helper
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- createSingleColumnFamilySchemaMap(
- metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
operatorName =
StatefulOperatorsUtils.DEDUPLICATE_WITHIN_WATERMARK_EXEC_OP_NAME
)
}
@@ -595,16 +504,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
CheckAnswer(("a", 1))
)
- // Step 2: Define schemas for dropDuplicates with column specified
- val metadata =
-
DropDuplicatesTestUtils.getDropDuplicatesWithColumnSchemasWithMetadata()
-
// Perform round-trip test using common helper
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- createSingleColumnFamilySchemaMap(
- metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
operatorName = StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME
)
}
@@ -629,16 +532,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
StopStream
)
- // Step 2: Define schemas for session window aggregation
- val (keySchema, valueSchema) = SessionWindowTestUtils.getSchemas()
- // Session window aggregation uses prefix key scanning where
sessionId is the prefix
- val keyStateEncoderSpec = PrefixKeyScanStateEncoderSpec(keySchema, 1)
-
// Perform round-trip test using common helper
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- createSingleColumnFamilySchemaMap(keySchema, valueSchema,
keyStateEncoderSpec),
operatorName =
StatefulOperatorsUtils.SESSION_WINDOW_STATE_STORE_SAVE_EXEC_OP_NAME
)
}
@@ -660,15 +557,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
assertNumStateRows(total = 6, updated = 6)
)
- // Step 2: Define schemas for dropDuplicates (state version 2)
- val metadata =
DropDuplicatesTestUtils.getDropDuplicatesSchemasWithMetadata()
-
// Perform round-trip test using common helper
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- createSingleColumnFamilySchemaMap(
- metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
operatorName = StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME
)
}
@@ -713,15 +605,15 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
runQuery(sourceDir.getAbsolutePath, roundsOfData = 2)
runQuery(targetDir.getAbsolutePath, roundsOfData = 1)
- val allColFamilyNames = StreamStreamJoinTestUtils.KEY_TO_NUM_VALUES_ALL ++
- StreamStreamJoinTestUtils.KEY_WITH_INDEX_ALL
+ val allColFamilyNames = StreamStreamJoinTestUtils.allStoreNames.toList
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- getJoinV3ColumnSchemaMap(),
- columnFamilyToStateSourceOptions = allColFamilyNames.map {
- colName => colName -> Map(StateSourceOptions.STORE_NAME -> colName)
- }.toMap,
+ storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME -> allColFamilyNames),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> allColFamilyNames.map {
+ cfName => cfName -> Map(StateSourceOptions.STORE_NAME -> cfName)
+ }.toMap),
operatorName =
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
)
}
@@ -770,14 +662,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
runQuery(targetDir.getAbsolutePath, 1)
val schemas =
MultiStateVarProcessorTestUtils.getSchemasWithMetadata()
- val columnFamilyToSchemaMap = schemas.map { case (cfName,
metadata) =>
- cfName -> createColFamilyInfo(
- metadata.keySchema,
- metadata.valueSchema,
- metadata.encoderSpec,
- cfName,
- metadata.useMultipleValuePerKey)
- }
val columnFamilyToSelectExprs = MultiStateVarProcessorTestUtils
.getColumnFamilyToSelectExprs()
@@ -799,9 +683,11 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- columnFamilyToSchemaMap,
- columnFamilyToSelectExprs = columnFamilyToSelectExprs,
- columnFamilyToStateSourceOptions =
columnFamilyToStateSourceOptions,
+ storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME ->
schemas.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToSelectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
)
}
@@ -842,9 +728,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- schemaMap,
- columnFamilyToSelectExprs = selectExprs,
- columnFamilyToStateSourceOptions = stateSourceOptions,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions),
operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
)
}
@@ -890,9 +779,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- schemaMap,
- columnFamilyToSelectExprs = selectExprs,
- columnFamilyToStateSourceOptions = sourceOptions,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> sourceOptions),
operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
)
}
@@ -933,14 +825,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
)
val schemas =
TTLProcessorUtils.getListStateTTLSchemasWithMetadata()
- val columnFamilyToSchemaMap = schemas.map { case (cfName,
metadata) =>
- cfName -> createColFamilyInfo(
- metadata.keySchema,
- metadata.valueSchema,
- metadata.encoderSpec,
- cfName,
- metadata.useMultipleValuePerKey)
- }
val columnFamilyToSelectExprs = Map(
TTLProcessorUtils.LIST_STATE ->
TTLProcessorUtils.getTTLSelectExpressions(
@@ -963,9 +847,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- columnFamilyToSchemaMap,
- columnFamilyToSelectExprs = columnFamilyToSelectExprs,
- columnFamilyToStateSourceOptions =
columnFamilyToStateSourceOptions,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToSelectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
)
}
@@ -1006,14 +893,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
)
val schemas = TTLProcessorUtils.getMapStateTTLSchemasWithMetadata()
- val columnFamilyToSchemaMap = schemas.map { case (cfName,
metadata) =>
- cfName -> createColFamilyInfo(
- metadata.keySchema,
- metadata.valueSchema,
- metadata.encoderSpec,
- cfName,
- metadata.useMultipleValuePerKey)
- }
val columnFamilyToSelectExprs = Map(
TTLProcessorUtils.MAP_STATE ->
TTLProcessorUtils.getTTLSelectExpressions(
@@ -1027,9 +906,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- columnFamilyToSchemaMap,
- columnFamilyToSelectExprs = columnFamilyToSelectExprs,
- columnFamilyToStateSourceOptions =
columnFamilyToStateSourceOptions,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToSelectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
)
}
@@ -1071,14 +953,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
)
val schemas =
TTLProcessorUtils.getValueStateTTLSchemasWithMetadata()
- val columnFamilyToSchemaMap = schemas.map { case (cfName,
metadata) =>
- cfName -> createColFamilyInfo(
- metadata.keySchema,
- metadata.valueSchema,
- metadata.encoderSpec,
- cfName,
- metadata.useMultipleValuePerKey)
- }
val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
@@ -1087,8 +961,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
performRoundTripTest(
sourceDir.getAbsolutePath,
targetDir.getAbsolutePath,
- columnFamilyToSchemaMap,
- columnFamilyToStateSourceOptions =
columnFamilyToStateSourceOptions,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
)
}