This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2fb3c3571cc0 [SPARK-48891][SS] Refactor StateSchemaCompatibilityChecker to unify all state schema formats
2fb3c3571cc0 is described below

commit 2fb3c3571cc026c080449439327e22a1e1b342cf
Author: Anish Shrigondekar <anish.shrigonde...@databricks.com>
AuthorDate: Mon Jul 22 08:23:10 2024 +0900

    [SPARK-48891][SS] Refactor StateSchemaCompatibilityChecker to unify all state schema formats

    ### What changes were proposed in this pull request?
    Refactor StateSchemaCompatibilityChecker to unify all state schema formats

    ### Why are the changes needed?
    Needed to integrate future changes around state data source reader and schema evolution and consolidate these changes
    - Consolidates all state schema reader/writers in one place
    - Consolidates all validation logic through the same API

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Added unit tests

    ```
    12:38:45.481 WARN org.apache.spark.sql.execution.streaming.state.StateSchemaCompatibilityCheckerSuite:
    ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.execution.streaming.state.StateSchemaCompatibilityCheckerSuite, threads: rpc-boss-3-1 (daemon=true), ForkJoinPool.commonPool-worker-3 (daemon=true), ForkJoinPool.commonPool-worker-2 (daemon=true), shuffle-boss-6-1 (daemon=true), ForkJoinPool.commonPool-worker-1 (daemon=true) =====
    [info] Run completed in 12 seconds, 565 milliseconds.
    [info] Total number of tests run: 30
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 30, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    ```

    ### Was this patch authored or co-authored using generative AI tooling?
    No

    Closes #47359 from anishshri-db/task/SPARK-48891.

    Authored-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../datasources/v2/state/StateDataSource.scala     |   3 +-
 .../v2/state/StreamStreamJoinStateHelper.scala     |   6 +-
 .../streaming/FlatMapGroupsWithStateExec.scala     |   8 +-
 .../execution/streaming/IncrementalExecution.scala |   7 +-
 ...ala => StateStoreColumnFamilySchemaUtils.scala} |  42 +---
 .../streaming/StatefulProcessorHandleImpl.scala    |  20 +-
 .../streaming/StreamingSymmetricHashJoinExec.scala |  10 +-
 .../streaming/TransformWithStateExec.scala         |  39 +---
 .../execution/streaming/state/SchemaHelper.scala   | 227 +++++++++++----------
 .../state/StateSchemaCompatibilityChecker.scala    | 173 ++++++++++------
 .../streaming/state/StateSchemaV3File.scala        |  99 ---------
 .../execution/streaming/statefulOperators.scala    |  56 +++--
 .../sql/execution/streaming/streamingLimits.scala  |  10 +-
 .../StateSchemaCompatibilityCheckerSuite.scala     | 158 +++++++++++---
 .../sql/streaming/TransformWithStateSuite.scala    |  98 ++------
 .../TransformWithValueStateTTLSuite.scala          |  25 ++-
 16 files changed, 490 insertions(+), 491 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index e2c5499fe439..4e7f2f1c41e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -91,7 +91,8 @@ class StateDataSource extends TableProvider with DataSourceRegister {
         partitionId, sourceOptions.storeName)
       val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
       val manager
= new StateSchemaCompatibilityChecker(providerId, hadoopConf) - manager.readSchemaFile() + val stateSchema = manager.readSchemaFile().head + (stateSchema.keySchema, stateSchema.valueSchema) } if (sourceOptions.readChangeFeed) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala index 7b08c289fcc4..1a04d24f0048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala @@ -67,12 +67,14 @@ object StreamStreamJoinStateHelper { val newHadoopConf = session.sessionState.newHadoopConf() + // read the key schema from the keyToNumValues store for the join keys val manager = new StateSchemaCompatibilityChecker(providerIdForKeyToNumValues, newHadoopConf) - val (keySchema, _) = manager.readSchemaFile() + val keySchema = manager.readSchemaFile().head.keySchema + // read the value schema from the keyWithIndexToValue store for the values val manager2 = new StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue, newHadoopConf) - val (_, valueSchema) = manager2.readSchemaFile() + val valueSchema = manager2.readSchemaFile().head.valueSchema val maybeMatchedColumn = valueSchema.last diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index e2ef656fce75..3ee1fc1db71f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -192,9 +192,11 @@ trait FlatMapGroupsWithStateExecBase override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, - stateSchemaVersion: Int): Array[String] = { - StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - groupingAttributes.toStructType, stateManager.stateSchema, session.sessionState) + stateSchemaVersion: Int): List[StateSchemaValidationResult] = { + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + groupingAttributes.toStructType, stateManager.stateSchema)) + List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, + newStateSchema, session.sessionState, stateSchemaVersion)) } override protected def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 772c26ac7667..722a3bd86b7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -201,11 +201,12 @@ class IncrementalExecution( // filepath, and write this path out in the OperatorStateMetadata file case statefulOp: StatefulOperator if isFirstBatch => val stateSchemaVersion = statefulOp match { - case _: TransformWithStateExec => sparkSession.sessionState.conf. 
- getConf(SQLConf.STREAMING_TRANSFORM_WITH_STATE_OP_STATE_SCHEMA_VERSION) + case _: TransformWithStateExec => + sparkSession.sessionState.conf. + getConf(SQLConf.STREAMING_TRANSFORM_WITH_STATE_OP_STATE_SCHEMA_VERSION) case _ => STATE_SCHEMA_DEFAULT_VERSION } - val stateSchemaPaths = statefulOp. + val schemaValidationResult = statefulOp. validateAndMaybeEvolveStateSchema(hadoopConf, currentBatchId, stateSchemaVersion) // write out the state schema paths to the metadata file statefulOp match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala similarity index 61% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index 513f941bdfbd..99229c6132eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -19,52 +19,32 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ -import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema} -trait ColumnFamilySchemaUtils { - def getValueStateSchema[T]( - stateName: String, - keyEncoder: ExpressionEncoder[Any], - valEncoder: Encoder[T], - hasTtl: Boolean): ColumnFamilySchema - - def getListStateSchema[T]( - stateName: String, - keyEncoder: ExpressionEncoder[Any], - valEncoder: Encoder[T], - hasTtl: Boolean): ColumnFamilySchema - - def getMapStateSchema[K, V]( - stateName: String, - keyEncoder: ExpressionEncoder[Any], - userKeyEnc: Encoder[K], - valEncoder: Encoder[V], - hasTtl: Boolean): ColumnFamilySchema -} -object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils { +object StateStoreColumnFamilySchemaUtils { def getValueStateSchema[T]( stateName: String, keyEncoder: ExpressionEncoder[Any], valEncoder: Encoder[T], - hasTtl: Boolean): ColumnFamilySchemaV1 = { - new ColumnFamilySchemaV1( + hasTtl: Boolean): StateStoreColFamilySchema = { + StateStoreColFamilySchema( stateName, keyEncoder.schema, getValueSchemaWithTTL(valEncoder.schema, hasTtl), - NoPrefixKeyStateEncoderSpec(keyEncoder.schema)) + Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema))) } def getListStateSchema[T]( stateName: String, keyEncoder: ExpressionEncoder[Any], valEncoder: Encoder[T], - hasTtl: Boolean): ColumnFamilySchemaV1 = { - new ColumnFamilySchemaV1( + hasTtl: Boolean): StateStoreColFamilySchema = { + StateStoreColFamilySchema( stateName, keyEncoder.schema, getValueSchemaWithTTL(valEncoder.schema, hasTtl), - NoPrefixKeyStateEncoderSpec(keyEncoder.schema)) + Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema))) } def getMapStateSchema[K, V]( @@ -72,13 +52,13 @@ object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils { keyEncoder: ExpressionEncoder[Any], userKeyEnc: Encoder[K], valEncoder: Encoder[V], - hasTtl: Boolean): ColumnFamilySchemaV1 = { + 
hasTtl: Boolean): StateStoreColFamilySchema = { val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema) - new ColumnFamilySchemaV1( + StateStoreColFamilySchema( stateName, compositeKeySchema, getValueSchemaWithTTL(valEncoder.schema, hasTtl), - PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1), + Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), Some(userKeyEnc.schema)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 3031faa35b2d..44e2e6838f4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -300,18 +300,16 @@ class StatefulProcessorHandleImpl( class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any]) extends StatefulProcessorHandleImplBase(timeMode, keyExprEnc) { - private[sql] val columnFamilySchemaUtils = ColumnFamilySchemaUtilsV1 - // Because this is only happening on the driver side, there is only // one task modifying and accessing this map at a time - private[sql] val columnFamilySchemas: mutable.Map[String, ColumnFamilySchema] = - new mutable.HashMap[String, ColumnFamilySchema]() + private[sql] val columnFamilySchemas: mutable.Map[String, StateStoreColFamilySchema] = + new mutable.HashMap[String, StateStoreColFamilySchema]() - def getColumnFamilySchemas: Map[String, ColumnFamilySchema] = columnFamilySchemas.toMap + def getColumnFamilySchemas: Map[String, StateStoreColFamilySchema] = columnFamilySchemas.toMap override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) - val colFamilySchema = columnFamilySchemaUtils. + val colFamilySchema = StateStoreColumnFamilySchemaUtils. getValueStateSchema(stateName, keyExprEnc, valEncoder, false) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ValueState[T]] @@ -322,7 +320,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi valEncoder: Encoder[T], ttlConfig: TTLConfig): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) - val colFamilySchema = columnFamilySchemaUtils. + val colFamilySchema = StateStoreColumnFamilySchemaUtils. getValueStateSchema(stateName, keyExprEnc, valEncoder, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ValueState[T]] @@ -330,7 +328,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { verifyStateVarOperations("get_list_state", PRE_INIT) - val colFamilySchema = columnFamilySchemaUtils. + val colFamilySchema = StateStoreColumnFamilySchemaUtils. getListStateSchema(stateName, keyExprEnc, valEncoder, false) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ListState[T]] @@ -341,7 +339,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi valEncoder: Encoder[T], ttlConfig: TTLConfig): ListState[T] = { verifyStateVarOperations("get_list_state", PRE_INIT) - val colFamilySchema = columnFamilySchemaUtils. + val colFamilySchema = StateStoreColumnFamilySchemaUtils. 
getListStateSchema(stateName, keyExprEnc, valEncoder, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[ListState[T]] @@ -352,7 +350,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi userKeyEnc: Encoder[K], valEncoder: Encoder[V]): MapState[K, V] = { verifyStateVarOperations("get_map_state", PRE_INIT) - val colFamilySchema = columnFamilySchemaUtils. + val colFamilySchema = StateStoreColumnFamilySchemaUtils. getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[MapState[K, V]] @@ -364,7 +362,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi valEncoder: Encoder[V], ttlConfig: TTLConfig): MapState[K, V] = { verifyStateVarOperations("get_map_state", PRE_INIT) - val colFamilySchema = columnFamilySchemaUtils. + val colFamilySchema = StateStoreColumnFamilySchemaUtils. getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true) columnFamilySchemas.put(stateName, colFamilySchema) null.asInstanceOf[MapState[K, V]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index ea275a28780e..a303d4db66a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -249,7 +249,7 @@ case class StreamingSymmetricHashJoinExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, - stateSchemaVersion: Int): Array[String] = { + stateSchemaVersion: Int): List[StateSchemaValidationResult] = { var result: Map[String, (StructType, StructType)] = Map.empty // get state schema for state stores on left side of the join result ++= SymmetricHashJoinStateManager.getSchemaForStateStores(LeftSide, @@ -260,10 +260,12 @@ case class StreamingSymmetricHashJoinExec( right.output, rightKeys, stateFormatVersion) // validate and maybe evolve schema for all state stores across both sides of the join - result.iterator.flatMap { case (stateStoreName, (keySchema, valueSchema)) => + result.map { case (stateStoreName, (keySchema, valueSchema)) => + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + keySchema, valueSchema)) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - keySchema, valueSchema, session.sessionState, storeName = stateStoreName) - }.toArray + newStateSchema, session.sessionState, stateSchemaVersion, storeName = stateStoreName) + }.toList } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 9014878178a1..a4d525ad13fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -113,7 +113,7 @@ case class TransformWithStateExec( * Fetching the columnFamilySchemas from the StatefulProcessorHandle * after init is called. 
*/ - private def getColFamilySchemas(): Map[String, ColumnFamilySchema] = { + private def getColFamilySchemas(): Map[String, StateStoreColFamilySchema] = { val columnFamilySchemas = getDriverProcessorHandle().getColumnFamilySchemas closeProcessorHandle() columnFamilySchemas @@ -380,40 +380,21 @@ case class TransformWithStateExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, - stateSchemaVersion: Int): Array[String] = { + stateSchemaVersion: Int): List[StateSchemaValidationResult] = { assert(stateSchemaVersion >= 3) val newColumnFamilySchemas = getColFamilySchemas() - val schemaFile = new StateSchemaV3File( - hadoopConf, stateSchemaDirPath(StateStoreId.DEFAULT_STORE_NAME).toString) - // TODO: [SPARK-48849] Read the schema path from the OperatorStateMetadata file - // and validate it with the new schema - - // Write the new schema to the schema file - val schemaPath = schemaFile.addWithUUID(batchId, newColumnFamilySchemas.values.toList) - Array(schemaPath.toString) + val stateSchemaDir = stateSchemaDirPath() + val stateSchemaFilePath = new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}") + List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, + newColumnFamilySchemas.values.toList, session.sessionState, stateSchemaVersion, + schemaFilePath = Some(stateSchemaFilePath))) } - private def validateSchemas( - oldSchemas: List[ColumnFamilySchema], - newSchemas: Map[String, ColumnFamilySchema]): Unit = { - oldSchemas.foreach { case oldSchema: ColumnFamilySchemaV1 => - newSchemas.get(oldSchema.columnFamilyName).foreach { - case newSchema: ColumnFamilySchemaV1 => - StateSchemaCompatibilityChecker.check( - (oldSchema.keySchema, oldSchema.valueSchema), - (newSchema.keySchema, newSchema.valueSchema), - ignoreValueSchema = false - ) - } - } - } - - private def stateSchemaDirPath(storeName: String): Path = { - assert(storeName == StateStoreId.DEFAULT_STORE_NAME) - def stateInfo = getStateInfo + private def stateSchemaDirPath(): Path = { + val storeName = StateStoreId.DEFAULT_STORE_NAME val stateCheckpointPath = new Path(getStateInfo.checkpointLocation, - s"${stateInfo.operatorId.toString}") + s"${getStateInfo.operatorId.toString}") val storeNamePath = new Path(stateCheckpointPath, storeName) new Path(new Path(storeNamePath, "_metadata"), "schema") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index 0a8021ab3de2..f6737307fad1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.execution.streaming.state import java.io.StringReader +import scala.util.Try + import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream} import org.json4s.DefaultFormats -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods -import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.sql.execution.streaming.MetadataVersionUtil import org.apache.spark.sql.types.StructType @@ -33,106 +32,104 @@ import org.apache.spark.util.Utils /** * Helper classes for reading/writing state schema. 
*/ -sealed trait ColumnFamilySchema extends Serializable { - def jsonValue: JValue - - def json: String - - def columnFamilyName: String -} - -case class ColumnFamilySchemaV1( - columnFamilyName: String, - keySchema: StructType, - valueSchema: StructType, - keyStateEncoderSpec: KeyStateEncoderSpec, - userKeyEncoder: Option[StructType] = None) extends ColumnFamilySchema { - def jsonValue: JValue = { - ("columnFamilyName" -> JString(columnFamilyName)) ~ - ("keySchema" -> JString(keySchema.json)) ~ - ("valueSchema" -> JString(valueSchema.json)) ~ - ("keyStateEncoderSpec" -> keyStateEncoderSpec.jsonValue) ~ - ("userKeyEncoder" -> userKeyEncoder.map(s => JString(s.json)).getOrElse(JNothing)) - } - - def json: String = { - compact(render(jsonValue)) - } -} - -object ColumnFamilySchemaV1 { - - /** - * Create a ColumnFamilySchemaV1 object from the Json string - * This function is to read the StateSchemaV3 file - */ - def fromJson(json: String): ColumnFamilySchema = { - implicit val formats: DefaultFormats.type = DefaultFormats - val colFamilyMap = JsonMethods.parse(json).extract[Map[String, Any]] - assert(colFamilyMap.isInstanceOf[Map[_, _]], - s"Expected Map but got ${colFamilyMap.getClass}") - val keySchema = StructType.fromString(colFamilyMap("keySchema").asInstanceOf[String]) - val valueSchema = StructType.fromString(colFamilyMap("valueSchema").asInstanceOf[String]) - ColumnFamilySchemaV1( - colFamilyMap("columnFamilyName").asInstanceOf[String], - keySchema, - valueSchema, - KeyStateEncoderSpec.fromJson(keySchema, colFamilyMap("keyStateEncoderSpec") - .asInstanceOf[Map[String, Any]]), - colFamilyMap.get("userKeyEncoder").map(_.asInstanceOf[String]).map(StructType.fromString) - ) - } -} +case class StateSchemaFormatV3( + stateStoreColFamilySchema: List[StateStoreColFamilySchema] +) object SchemaHelper { sealed trait SchemaReader { - def read(inputStream: FSDataInputStream): (StructType, StructType) + def version: Int + + def read(inputStream: FSDataInputStream): List[StateStoreColFamilySchema] + + protected def readJsonSchema(inputStream: FSDataInputStream): String = { + val buf = new StringBuilder + val numChunks = inputStream.readInt() + (0 until numChunks).foreach(_ => buf.append(inputStream.readUTF())) + buf.toString() + } } object SchemaReader { def createSchemaReader(versionStr: String): SchemaReader = { val version = MetadataVersionUtil.validateVersion(versionStr, - StateSchemaCompatibilityChecker.VERSION) + 3) version match { case 1 => new SchemaV1Reader case 2 => new SchemaV2Reader + case 3 => new SchemaV3Reader } } } class SchemaV1Reader extends SchemaReader { - def read(inputStream: FSDataInputStream): (StructType, StructType) = { + override def version: Int = 1 + + override def read(inputStream: FSDataInputStream): List[StateStoreColFamilySchema] = { val keySchemaStr = inputStream.readUTF() val valueSchemaStr = inputStream.readUTF() - (StructType.fromString(keySchemaStr), StructType.fromString(valueSchemaStr)) + List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + StructType.fromString(keySchemaStr), + StructType.fromString(valueSchemaStr))) } } class SchemaV2Reader extends SchemaReader { - def read(inputStream: FSDataInputStream): (StructType, StructType) = { - val buf = new StringBuilder - val numKeyChunks = inputStream.readInt() - (0 until numKeyChunks).foreach(_ => buf.append(inputStream.readUTF())) - val keySchemaStr = buf.toString() - - buf.clear() - val numValueChunks = inputStream.readInt() - (0 until numValueChunks).foreach(_ => 
buf.append(inputStream.readUTF())) - val valueSchemaStr = buf.toString() - (StructType.fromString(keySchemaStr), StructType.fromString(valueSchemaStr)) + override def version: Int = 2 + + override def read(inputStream: FSDataInputStream): List[StateStoreColFamilySchema] = { + val keySchemaStr = readJsonSchema(inputStream) + val valueSchemaStr = readJsonSchema(inputStream) + + List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + StructType.fromString(keySchemaStr), + StructType.fromString(valueSchemaStr))) + } + } + + class SchemaV3Reader extends SchemaReader { + override def version: Int = 3 + + override def read(inputStream: FSDataInputStream): List[StateStoreColFamilySchema] = { + implicit val formats: DefaultFormats.type = DefaultFormats + val numEntries = inputStream.readInt() + (0 until numEntries).map { _ => + // read the col family name and the key and value schema + val colFamilyName = inputStream.readUTF() + val keySchemaStr = readJsonSchema(inputStream) + val valueSchemaStr = readJsonSchema(inputStream) + val keySchema = StructType.fromString(keySchemaStr) + + // use the key schema to also populate the encoder spec + val keyEncoderSpecStr = readJsonSchema(inputStream) + val colFamilyMap = JsonMethods.parse(keyEncoderSpecStr).extract[Map[String, Any]] + val encoderSpec = KeyStateEncoderSpec.fromJson(keySchema, colFamilyMap) + + // read the user key encoder spec if provided + val userKeyEncoderSchemaStr = readJsonSchema(inputStream) + val userKeyEncoderSchema = Try(StructType.fromString(userKeyEncoderSchemaStr)).toOption + + StateStoreColFamilySchema(colFamilyName, + keySchema, + StructType.fromString(valueSchemaStr), + Some(encoderSpec), + userKeyEncoderSchema) + }.toList } } trait SchemaWriter { - val version: Int + // 2^16 - 1 bytes + final val MAX_UTF_CHUNK_SIZE = 65535 + + def version: Int final def write( - keySchema: StructType, - valueSchema: StructType, + stateStoreColFamilySchema: List[StateStoreColFamilySchema], outputStream: FSDataOutputStream): Unit = { writeVersion(outputStream) - writeSchema(keySchema, valueSchema, outputStream) + writeSchema(stateStoreColFamilySchema, outputStream) } private def writeVersion(outputStream: FSDataOutputStream): Unit = { @@ -140,9 +137,24 @@ object SchemaHelper { } protected def writeSchema( - keySchema: StructType, - valueSchema: StructType, + stateStoreColFamilySchema: List[StateStoreColFamilySchema], outputStream: FSDataOutputStream): Unit + + protected def writeJsonSchema( + outputStream: FSDataOutputStream, + jsonSchema: String): Unit = { + // DataOutputStream.writeUTF can't write a string at once + // if the size exceeds 65535 (2^16 - 1) bytes. + // So a key as well as a value consist of multiple chunks in schema version 2 and above. 
+ val buf = new Array[Char](MAX_UTF_CHUNK_SIZE) + val numChunks = (jsonSchema.length - 1) / MAX_UTF_CHUNK_SIZE + 1 + val stringReader = new StringReader(jsonSchema) + outputStream.writeInt(numChunks) + (0 until numChunks).foreach { _ => + val numRead = stringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE) + outputStream.writeUTF(new String(buf, 0, numRead)) + } + } } object SchemaWriter { @@ -150,53 +162,60 @@ object SchemaHelper { version match { case 1 if Utils.isTesting => new SchemaV1Writer case 2 => new SchemaV2Writer + case 3 => new SchemaV3Writer } } } class SchemaV1Writer extends SchemaWriter { - val version: Int = 1 + override def version: Int = 1 def writeSchema( - keySchema: StructType, - valueSchema: StructType, + stateStoreColFamilySchema: List[StateStoreColFamilySchema], outputStream: FSDataOutputStream): Unit = { - outputStream.writeUTF(keySchema.json) - outputStream.writeUTF(valueSchema.json) + assert(stateStoreColFamilySchema.length == 1) + val stateSchema = stateStoreColFamilySchema.head + outputStream.writeUTF(stateSchema.keySchema.json) + outputStream.writeUTF(stateSchema.valueSchema.json) } } class SchemaV2Writer extends SchemaWriter { - val version: Int = 2 - - // 2^16 - 1 bytes - final val MAX_UTF_CHUNK_SIZE = 65535 + override def version: Int = 2 def writeSchema( - keySchema: StructType, - valueSchema: StructType, + stateStoreColFamilySchema: List[StateStoreColFamilySchema], outputStream: FSDataOutputStream): Unit = { - val buf = new Array[Char](MAX_UTF_CHUNK_SIZE) + assert(stateStoreColFamilySchema.length == 1) + val stateSchema = stateStoreColFamilySchema.head - // DataOutputStream.writeUTF can't write a string at once - // if the size exceeds 65535 (2^16 - 1) bytes. - // So a key as well as a value consist of multiple chunks in schema version 2. 
- val keySchemaJson = keySchema.json - val numKeyChunks = (keySchemaJson.length - 1) / MAX_UTF_CHUNK_SIZE + 1 - val keyStringReader = new StringReader(keySchemaJson) - outputStream.writeInt(numKeyChunks) - (0 until numKeyChunks).foreach { _ => - val numRead = keyStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE) - outputStream.writeUTF(new String(buf, 0, numRead)) - } + writeJsonSchema(outputStream, stateSchema.keySchema.json) + writeJsonSchema(outputStream, stateSchema.valueSchema.json) + } + } - val valueSchemaJson = valueSchema.json - val numValueChunks = (valueSchemaJson.length - 1) / MAX_UTF_CHUNK_SIZE + 1 - val valueStringReader = new StringReader(valueSchemaJson) - outputStream.writeInt(numValueChunks) - (0 until numValueChunks).foreach { _ => - val numRead = valueStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE) - outputStream.writeUTF(new String(buf, 0, numRead)) + class SchemaV3Writer extends SchemaWriter { + override def version: Int = 3 + + private val emptyJsonStr = """{ }""" + + def writeSchema( + stateStoreColFamilySchema: List[StateStoreColFamilySchema], + outputStream: FSDataOutputStream): Unit = { + outputStream.writeInt(stateStoreColFamilySchema.size) + stateStoreColFamilySchema.foreach { colFamilySchema => + assert(colFamilySchema.keyStateEncoderSpec.isDefined) + outputStream.writeUTF(colFamilySchema.colFamilyName) + writeJsonSchema(outputStream, colFamilySchema.keySchema.json) + writeJsonSchema(outputStream, colFamilySchema.valueSchema.json) + writeJsonSchema(outputStream, colFamilySchema.keyStateEncoderSpec.get.json) + // write user key encoder schema if provided and empty json otherwise + val userKeyEncoderStr = if (colFamilySchema.userKeyEncoderSchema.isDefined) { + colFamilySchema.userKeyEncoderSchema.get.json + } else { + emptyJsonStr + } + writeJsonSchema(outputStream, userKeyEncoderStr) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 8aabc0846fe6..3230098c74cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -30,22 +30,45 @@ import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.{DataType, StructType} +// Result returned after validating the schema of the state store for schema changes +case class StateSchemaValidationResult( + evolvedSchema: Boolean, + schemaPath: String +) + +// Used to represent the schema of a column family in the state store +case class StateStoreColFamilySchema( + colFamilyName: String, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: Option[KeyStateEncoderSpec] = None, + userKeyEncoderSchema: Option[StructType] = None +) + class StateSchemaCompatibilityChecker( providerId: StateStoreProviderId, - hadoopConf: Configuration) extends Logging { + hadoopConf: Configuration, + schemaFilePath: Option[Path] = None) extends Logging { - private val storeCpLocation = providerId.storeId.storeCheckpointLocation() - private val fm = CheckpointFileManager.create(storeCpLocation, hadoopConf) - private val schemaFileLocation = schemaFile(storeCpLocation) - private val schemaWriter = - 
SchemaWriter.createSchemaWriter(StateSchemaCompatibilityChecker.VERSION) + private val schemaFileLocation = if (schemaFilePath.isEmpty) { + val storeCpLocation = providerId.storeId.storeCheckpointLocation() + schemaFile(storeCpLocation) + } else { + schemaFilePath.get + } + + private val fm = CheckpointFileManager.create(schemaFileLocation, hadoopConf) fm.mkdirs(schemaFileLocation.getParent) - def readSchemaFile(): (StructType, StructType) = { + def readSchemaFile(): List[StateStoreColFamilySchema] = { val inStream = fm.open(schemaFileLocation) try { val versionStr = inStream.readUTF() + // Ensure that version 3 format has schema file path provided explicitly + if (versionStr == "v3" && schemaFilePath.isEmpty) { + throw new IllegalStateException("Schema file path is required for schema version 3") + } val schemaReader = SchemaReader.createSchemaReader(versionStr) schemaReader.read(inStream) } catch { @@ -58,30 +81,38 @@ class StateSchemaCompatibilityChecker( } /** - * Function to read and return the existing key and value schema from the schema file, if it - * exists - * @return - Option of (keySchema, valueSchema) if the schema file exists, None otherwise + * Function to read and return the list of existing state store column family schemas from the + * schema file, if it exists + * @return - List of state store column family schemas if the schema file exists and empty l + * otherwise */ - private def getExistingKeyAndValueSchema(): Option[(StructType, StructType)] = { + private def getExistingKeyAndValueSchema(): List[StateStoreColFamilySchema] = { if (fm.exists(schemaFileLocation)) { - Some(readSchemaFile()) + readSchemaFile() } else { - None + List.empty } } - private def createSchemaFile(keySchema: StructType, valueSchema: StructType): Unit = { - createSchemaFile(keySchema, valueSchema, schemaWriter) + private def createSchemaFile( + stateStoreColFamilySchema: List[StateStoreColFamilySchema], + stateSchemaVersion: Int): Unit = { + // Ensure that schema file path is passed explicitly for schema version 3 + if (stateSchemaVersion == 3 && schemaFilePath.isEmpty) { + throw new IllegalStateException("Schema file path is required for schema version 3") + } + + val schemaWriter = SchemaWriter.createSchemaWriter(stateSchemaVersion) + createSchemaFile(stateStoreColFamilySchema, schemaWriter) } // Visible for testing private[sql] def createSchemaFile( - keySchema: StructType, - valueSchema: StructType, + stateStoreColFamilySchema: List[StateStoreColFamilySchema], schemaWriter: SchemaWriter): Unit = { val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false) try { - schemaWriter.write(keySchema, valueSchema, outStream) + schemaWriter.write(stateStoreColFamilySchema, outStream) outStream.close() } catch { case e: Throwable => @@ -91,27 +122,8 @@ class StateSchemaCompatibilityChecker( } } - def validateAndMaybeEvolveStateSchema( - newKeySchema: StructType, - newValueSchema: StructType, - ignoreValueSchema: Boolean): Unit = { - val existingSchema = getExistingKeyAndValueSchema() - if (existingSchema.isEmpty) { - // write the schema file if it doesn't exist - createSchemaFile(newKeySchema, newValueSchema) - } else { - // validate if the new schema is compatible with the existing schema - StateSchemaCompatibilityChecker. 
- check(existingSchema.get, (newKeySchema, newValueSchema), ignoreValueSchema) - } - } - - private def schemaFile(storeCpLocation: Path): Path = - new Path(new Path(storeCpLocation, "_metadata"), "schema") -} - -object StateSchemaCompatibilityChecker extends Logging { - val VERSION = 2 + private def schemasCompatible(storedSchema: StructType, schema: StructType): Boolean = + DataType.equalsIgnoreNameAndCompatibleNullability(schema, storedSchema) /** * Function to check if new state store schema is compatible with the existing schema. @@ -119,12 +131,13 @@ object StateSchemaCompatibilityChecker extends Logging { * @param newSchema - new state schema * @param ignoreValueSchema - whether to ignore value schema or not */ - def check( - oldSchema: (StructType, StructType), - newSchema: (StructType, StructType), + private def check( + oldSchema: StateStoreColFamilySchema, + newSchema: StateStoreColFamilySchema, ignoreValueSchema: Boolean) : Unit = { - val (storedKeySchema, storedValueSchema) = oldSchema - val (keySchema, valueSchema) = newSchema + val (storedKeySchema, storedValueSchema) = (oldSchema.keySchema, + oldSchema.valueSchema) + val (keySchema, valueSchema) = (newSchema.keySchema, newSchema.valueSchema) if (storedKeySchema.equals(keySchema) && (ignoreValueSchema || storedValueSchema.equals(valueSchema))) { @@ -140,9 +153,39 @@ object StateSchemaCompatibilityChecker extends Logging { } } - private def schemasCompatible(storedSchema: StructType, schema: StructType): Boolean = - DataType.equalsIgnoreNameAndCompatibleNullability(schema, storedSchema) + /** + * Function to validate the new state store schema and evolve the schema if required. + * @param newStateSchema - proposed new state store schema by the operator + * @param ignoreValueSchema - whether to ignore value schema compatibility checks or not + * @param stateSchemaVersion - version of the state schema to be used + * @return - true if the schema has evolved, false otherwise + */ + def validateAndMaybeEvolveStateSchema( + newStateSchema: List[StateStoreColFamilySchema], + ignoreValueSchema: Boolean, + stateSchemaVersion: Int): Boolean = { + val existingStateSchemaList = getExistingKeyAndValueSchema().sortBy(_.colFamilyName) + val newStateSchemaList = newStateSchema.sortBy(_.colFamilyName) + + if (existingStateSchemaList.isEmpty) { + // write the schema file if it doesn't exist + createSchemaFile(newStateSchemaList, stateSchemaVersion) + true + } else { + // validate if the new schema is compatible with the existing schema + existingStateSchemaList.lazyZip(newStateSchemaList).foreach { + case (existingStateSchema, newStateSchema) => + check(existingStateSchema, newStateSchema, ignoreValueSchema) + } + false + } + } + private def schemaFile(storeCpLocation: Path): Path = + new Path(new Path(storeCpLocation, "_metadata"), "schema") +} + +object StateSchemaCompatibilityChecker { private def disallowBinaryInequalityColumn(schema: StructType): Unit = { if (!UnsafeRowUtils.isBinaryStable(schema)) { throw new SparkUnsupportedOperationException( @@ -160,20 +203,23 @@ object StateSchemaCompatibilityChecker extends Logging { * * @param stateInfo - StatefulOperatorStateInfo containing the state store information * @param hadoopConf - Hadoop configuration - * @param newKeySchema - New key schema - * @param newValueSchema - New value schema + * @param newStateSchema - Array of new schema for state store column families * @param sessionState - session state used to retrieve session config + * @param stateSchemaVersion - version of the state 
schema to be used * @param extraOptions - any extra options to be passed for StateStoreConf creation * @param storeName - optional state store name + * @param schemaFilePath - optional schema file path + * @return - StateSchemaValidationResult containing the result of the schema validation */ def validateAndMaybeEvolveStateSchema( stateInfo: StatefulOperatorStateInfo, hadoopConf: Configuration, - newKeySchema: StructType, - newValueSchema: StructType, + newStateSchema: List[StateStoreColFamilySchema], sessionState: SessionState, + stateSchemaVersion: Int, extraOptions: Map[String, String] = Map.empty, - storeName: String = StateStoreId.DEFAULT_STORE_NAME): Array[String] = { + storeName: String = StateStoreId.DEFAULT_STORE_NAME, + schemaFilePath: Option[Path] = None): StateSchemaValidationResult = { // SPARK-47776: collation introduces the concept of binary (in)equality, which means // in some collation we no longer be able to just compare the binary format of two // UnsafeRows to determine equality. For example, 'aaa' and 'AAA' can be "semantically" @@ -183,25 +229,38 @@ object StateSchemaCompatibilityChecker extends Logging { // We need to disallow using binary inequality column in the key schema, before we // could support this in majority of state store providers (or high-level of state // store.) - disallowBinaryInequalityColumn(newKeySchema) + newStateSchema.foreach { schema => + disallowBinaryInequalityColumn(schema.keySchema) + } val storeConf = new StateStoreConf(sessionState.conf, extraOptions) val providerId = StateStoreProviderId(StateStoreId(stateInfo.checkpointLocation, stateInfo.operatorId, 0, storeName), stateInfo.queryRunId) - val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf) + val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf, + schemaFilePath = schemaFilePath) // regardless of configuration, we check compatibility to at least write schema file // if necessary // if the format validation for value schema is disabled, we also disable the schema // compatibility checker for value schema as well. + + // Currently - schema evolution can happen only once per query. Basically for the initial batch + // there is no previous schema. So we classify that case under schema evolution. In the future, + // newer stateSchemaVersions will support evolution through the lifetime of the query as well. 
+ var evolvedSchema = false val result = Try( - checker.validateAndMaybeEvolveStateSchema(newKeySchema, newValueSchema, - ignoreValueSchema = !storeConf.formatValidationCheckValue) - ).toEither.fold(Some(_), _ => None) + checker.validateAndMaybeEvolveStateSchema(newStateSchema, + ignoreValueSchema = !storeConf.formatValidationCheckValue, + stateSchemaVersion = stateSchemaVersion) + ).toEither.fold(Some(_), + hasEvolvedSchema => { + evolvedSchema = hasEvolvedSchema + None + }) // if schema validation is enabled and an exception is thrown, we re-throw it and fail the query if (storeConf.stateSchemaCheckEnabled && result.isDefined) { throw result.get } - Array(checker.schemaFileLocation.toString) + StateSchemaValidationResult(evolvedSchema, checker.schemaFileLocation.toString) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala deleted file mode 100644 index 482e802b7d87..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.state - -import java.io.{InputStream, OutputStream} -import java.nio.charset.StandardCharsets.UTF_8 -import java.util.UUID - -import scala.io.{Source => IOSource} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path - -import org.apache.spark.sql.execution.streaming.CheckpointFileManager -import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVersion - -/** - * The StateSchemaV3File is used to write the schema of multiple column families. - * Right now, this is primarily used for the TransformWithState operator, which supports - * multiple column families to keep the data for multiple state variables. - * We only expect ColumnFamilySchemaV1 to be written and read from this file. - * @param hadoopConf Hadoop configuration that is used to read / write metadata files. - * @param path Path to the directory that will be used for writing metadata. 
- */ -class StateSchemaV3File( - hadoopConf: Configuration, - path: String) { - - val metadataPath = new Path(path) - - protected val fileManager: CheckpointFileManager = - CheckpointFileManager.create(metadataPath, hadoopConf) - - if (!fileManager.exists(metadataPath)) { - fileManager.mkdirs(metadataPath) - } - - private[sql] def deserialize(in: InputStream): List[ColumnFamilySchema] = { - val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() - - if (!lines.hasNext) { - throw new IllegalStateException("Incomplete log file in the offset commit log") - } - - val version = lines.next().trim - validateVersion(version, StateSchemaV3File.VERSION) - - lines.map(ColumnFamilySchemaV1.fromJson).toList - } - - private[sql] def serialize(schemas: List[ColumnFamilySchema], out: OutputStream): Unit = { - out.write(s"v${StateSchemaV3File.VERSION}".getBytes(UTF_8)) - out.write('\n') - out.write(schemas.map(_.json).mkString("\n").getBytes(UTF_8)) - } - - def addWithUUID(batchId: Long, metadata: List[ColumnFamilySchema]): Path = { - val schemaFilePath = new Path(metadataPath, s"${batchId}_${UUID.randomUUID().toString}") - write(schemaFilePath, out => serialize(metadata, out)) - schemaFilePath - } - - def getWithPath(schemaFilePath: Path): List[ColumnFamilySchema] = { - deserialize(fileManager.open(schemaFilePath)) - } - - protected def write( - batchMetadataFile: Path, - fn: OutputStream => Unit): Unit = { - val output = fileManager.createAtomic(batchMetadataFile, overwriteIfPossible = false) - try { - fn(output) - output.close() - } catch { - case e: Throwable => - output.cancel() - throw e - } - } -} - -object StateSchemaV3File { - val VERSION = 3 -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 94d976b568a5..14f67460763b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -78,7 +78,8 @@ trait StatefulOperator extends SparkPlan { // Returns the schema file path for operators that write this to the metadata file, // otherwise None def validateAndMaybeEvolveStateSchema( - hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): Array[String] + hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): + List[StateSchemaValidationResult] } /** @@ -435,9 +436,11 @@ case class StateStoreRestoreExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): - Array[String] = { - StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - keyExpressions.toStructType, stateManager.getStateValueSchema, session.sessionState) + List[StateSchemaValidationResult] = { + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + keyExpressions.toStructType, stateManager.getStateValueSchema)) + List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, + hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } override protected def doExecute(): RDD[InternalRow] = { @@ -503,10 +506,13 @@ case class StateStoreSaveExec( keyExpressions, child.output, stateFormatVersion) override def validateAndMaybeEvolveStateSchema( - hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): - Array[String] = { - 
StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - keyExpressions.toStructType, stateManager.getStateValueSchema, session.sessionState) + hadoopConf: Configuration, + batchId: Long, + stateSchemaVersion: Int): List[StateSchemaValidationResult] = { + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + keyExpressions.toStructType, stateManager.getStateValueSchema)) + List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, + hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } override protected def doExecute(): RDD[InternalRow] = { @@ -715,9 +721,11 @@ case class SessionWindowStateStoreRestoreExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): - Array[String] = { - StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - stateManager.getStateKeySchema, stateManager.getStateValueSchema, session.sessionState) + List[StateSchemaValidationResult] = { + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + stateManager.getStateKeySchema, stateManager.getStateValueSchema)) + List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, + newStateSchema, session.sessionState, stateSchemaVersion)) } override protected def doExecute(): RDD[InternalRow] = { @@ -804,9 +812,11 @@ case class SessionWindowStateStoreSaveExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): - Array[String] = { - StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - stateManager.getStateKeySchema, stateManager.getStateValueSchema, session.sessionState) + List[StateSchemaValidationResult] = { + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + stateManager.getStateKeySchema, stateManager.getStateValueSchema)) + List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, + newStateSchema, session.sessionState, stateSchemaVersion)) } override protected def doExecute(): RDD[InternalRow] = { @@ -1119,9 +1129,12 @@ case class StreamingDeduplicateExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): - Array[String] = { - StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - keyExpressions.toStructType, schemaForValueRow, session.sessionState, extraOptionOnStateStore) + List[StateSchemaValidationResult] = { + val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, + keyExpressions.toStructType, schemaForValueRow)) + List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, + newStateSchema, session.sessionState, stateSchemaVersion, + extraOptions = extraOptionOnStateStore)) } } @@ -1196,9 +1209,12 @@ case class StreamingDeduplicateWithinWatermarkExec( override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): - Array[String] = { - StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - keyExpressions.toStructType, schemaForValueRow, session.sessionState, extraOptionOnStateStore) + List[StateSchemaValidationResult] = { + val newStateSchema = 
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
+      keyExpressions.toStructType, schemaForValueRow))
+    List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
+      newStateSchema, session.sessionState, stateSchemaVersion,
+      extraOptions = extraOptionOnStateStore))
   }
 
   override protected def withNewChildInternal(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
index 7b3d393ec75d..0be2450c0ed1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, SortOrder, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{LimitExec, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateSchemaCompatibilityChecker, StateStoreOps}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreOps}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType}
 import org.apache.spark.util.{CompletionIterator, NextIterator}
@@ -49,9 +49,11 @@ case class StreamingGlobalLimitExec(
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration,
       batchId: Long,
       stateSchemaVersion: Int):
-    Array[String] = {
-    StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
-      keySchema, valueSchema, session.sessionState)
+    List[StateSchemaValidationResult] = {
+    val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
+      keySchema, valueSchema))
+    List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
+      newStateSchema, session.sessionState, stateSchemaVersion))
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala
index feab7a5fa3b0..f5a5d1277d05 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala
@@ -22,6 +22,7 @@ import java.util.UUID
 import scala.util.{Random, Try}
 
 import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
@@ -33,6 +34,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
 
   private val hadoopConf: Configuration = new Configuration()
   private val opId = Random.nextInt(100000)
+  private val batchId = 0
   private val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
 
   private val structSchema = new StructType()
@@ -235,14 +237,52 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
     val queryId = UUID.randomUUID()
     val providerId = StateStoreProviderId(
       StateStoreId(dir, opId, partitionId), queryId)
+    val storeColFamilySchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
+      keySchema, valueSchema))
     val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf)
-    checker.createSchemaFile(keySchema, valueSchema,
+    checker.createSchemaFile(storeColFamilySchema,
       SchemaHelper.SchemaWriter.createSchemaWriter(1))
-    val (resultKeySchema, resultValueSchema) = checker.readSchemaFile()
+    val stateSchema = checker.readSchemaFile().head
+    val (resultKeySchema, resultValueSchema) = (stateSchema.keySchema, stateSchema.valueSchema)
 
     assert((resultKeySchema, resultValueSchema) === (keySchema, valueSchema))
   }
 
+  Seq("NoPrefixKeyStateEncoderSpec", "PrefixKeyScanStateEncoderSpec",
+    "RangeKeyScanStateEncoderSpec").foreach { encoderSpecStr =>
+    test(s"checking for compatibility with schema version 3 with encoderSpec=$encoderSpecStr") {
+      val stateSchemaVersion = 3
+      val dir = newDir()
+      val queryId = UUID.randomUUID()
+      val providerId = StateStoreProviderId(
+        StateStoreId(dir, opId, partitionId), queryId)
+      val runId = UUID.randomUUID()
+      val stateInfo = StatefulOperatorStateInfo(dir, runId, opId, 0, 200)
+      val storeColFamilySchema = List(
+        StateStoreColFamilySchema("test1", keySchema, valueSchema,
+          keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, keySchema,
+            encoderSpecStr)),
+        StateStoreColFamilySchema("test2", longKeySchema, longValueSchema,
+          keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, longKeySchema,
+            encoderSpecStr)),
+        StateStoreColFamilySchema("test3", keySchema65535Bytes, valueSchema65535Bytes,
+          keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, keySchema65535Bytes)),
+        StateStoreColFamilySchema("test4", keySchema, valueSchema,
+          keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, keySchema,
+            encoderSpecStr),
+          userKeyEncoderSchema = Some(structSchema)))
+      val stateSchemaDir = stateSchemaDirPath(stateInfo)
+      val schemaFilePath = Some(new Path(stateSchemaDir,
+        s"${batchId}_${UUID.randomUUID().toString}"))
+      val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf,
+        schemaFilePath = schemaFilePath)
+      checker.createSchemaFile(storeColFamilySchema,
+        SchemaHelper.SchemaWriter.createSchemaWriter(stateSchemaVersion))
+      val stateSchema = checker.readSchemaFile()
+      assert(stateSchema.sortBy(_.colFamilyName) === storeColFamilySchema.sortBy(_.colFamilyName))
+    }
+  }
+
   test("SPARK-39650: ignore value schema on compatibility check") {
     val typeChangedValueSchema = StructType(valueSchema.map(_.copy(dataType = TimestampType)))
     verifySuccess(keySchema, valueSchema, keySchema, typeChangedValueSchema,
@@ -289,6 +329,36 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
     StructType(newFields)
   }
 
+  private def stateSchemaDirPath(stateInfo: StatefulOperatorStateInfo): Path = {
+    val storeName = StateStoreId.DEFAULT_STORE_NAME
+    val stateCheckpointPath =
+      new Path(stateInfo.checkpointLocation,
+        s"${stateInfo.operatorId.toString}")
+
+    val storeNamePath = new Path(stateCheckpointPath, storeName)
+    new Path(new Path(storeNamePath, "_metadata"), "schema")
+  }
+
+  private def getKeyStateEncoderSpec(
+      stateSchemaVersion: Int,
+      keySchema: StructType,
+      encoderSpec: String = "NoPrefixKeyStateEncoderSpec"): Option[KeyStateEncoderSpec] = {
+    if (stateSchemaVersion == 3) {
+      encoderSpec match {
+        case "NoPrefixKeyStateEncoderSpec" =>
+          Some(NoPrefixKeyStateEncoderSpec(keySchema))
+        case "PrefixKeyScanStateEncoderSpec" =>
+          Some(PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey = 1))
+        case "RangeKeyScanStateEncoderSpec" =>
+          Some(RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals = Seq(0)))
+        case _ =>
+          None
+      }
+    } else {
+      None
+    }
+  }
+
   private def verifyException(
       oldKeySchema: StructType,
       oldValueSchema: StructType,
@@ -303,33 +373,50 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
     val extraOptions = Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG ->
       formatValidationForValue.toString)
 
-    val result = Try(
-      StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
-        oldKeySchema, oldValueSchema, spark.sessionState, extraOptions)
-    ).toEither.fold(Some(_), _ => None)
+    Seq(2, 3).foreach { stateSchemaVersion =>
+      val schemaFilePath = if (stateSchemaVersion == 3) {
+        val stateSchemaDir = stateSchemaDirPath(stateInfo)
+        Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}"))
+      } else {
+        None
+      }
 
-    val ex = if (result.isDefined) {
-      result.get.asInstanceOf[SparkUnsupportedOperationException]
-    } else {
-      intercept[SparkUnsupportedOperationException] {
+      val oldStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
+        oldKeySchema, oldValueSchema,
+        keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, oldKeySchema)))
+      val result = Try(
         StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
-          newKeySchema, newValueSchema, spark.sessionState, extraOptions)
+          oldStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion,
+          schemaFilePath = schemaFilePath, extraOptions = extraOptions)
+      ).toEither.fold(Some(_), _ => None)
+
+      val ex = if (result.isDefined) {
+        result.get.asInstanceOf[SparkUnsupportedOperationException]
+      } else {
+        intercept[SparkUnsupportedOperationException] {
+          val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
+            newKeySchema, newValueSchema,
+            keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, newKeySchema)))
+          StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
+            newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion,
+            schemaFilePath = schemaFilePath, extraOptions = extraOptions)
+        }
      }
-    }
 
-    // collation checks are also performed in this path. so we need to check for them explicitly.
-    if (keyCollationChecks) {
-      assert(ex.getMessage.contains("Binary inequality column is not supported"))
-      assert(ex.getErrorClass === "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY")
-    } else {
-      if (ignoreValueSchema) {
-        // if value schema is ignored, the mismatch has to be on the key schema
-        assert(ex.getErrorClass === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE")
+      // collation checks are also performed in this path. so we need to check for them explicitly.
+      if (keyCollationChecks) {
+        assert(ex.getMessage.contains("Binary inequality column is not supported"))
+        assert(ex.getErrorClass === "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY")
       } else {
-        assert(ex.getErrorClass === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE" ||
-          ex.getErrorClass === "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE")
+        if (ignoreValueSchema) {
+          // if value schema is ignored, the mismatch has to be on the key schema
+          assert(ex.getErrorClass === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE")
+        } else {
+          assert(ex.getErrorClass === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE" ||
+            ex.getErrorClass === "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE")
+        }
+        assert(ex.getMessage.contains("does not match existing"))
       }
-      assert(ex.getMessage.contains("does not match existing"))
     }
   }
 
@@ -346,10 +433,27 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
     val extraOptions = Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG ->
      formatValidationForValue.toString)
 
-    StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
-      oldKeySchema, oldValueSchema, spark.sessionState, extraOptions)
+    Seq(2, 3).foreach { stateSchemaVersion =>
+      val schemaFilePath = if (stateSchemaVersion == 3) {
+        val stateSchemaDir = stateSchemaDirPath(stateInfo)
+        Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}"))
+      } else {
+        None
+      }
+
+      val oldStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
+        oldKeySchema, oldValueSchema,
+        keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, oldKeySchema)))
+      StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
+        oldStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion,
+        schemaFilePath = schemaFilePath, extraOptions = extraOptions)
 
-    StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
-      newKeySchema, newValueSchema, spark.sessionState, extraOptions)
+      val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
+        newKeySchema, newValueSchema,
+        keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, newKeySchema)))
+      StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
+        newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion,
+        schemaFilePath = schemaFilePath, extraOptions = extraOptions)
+    }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index c649dad76092..2e65748cb467 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, Encoders}
 import org.apache.spark.sql.catalyst.util.stringToFile
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMultipleColumnFamiliesNotSupportedException, TestClass}
+import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.functions.timestamp_seconds
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -805,82 +805,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - verify StateSchemaV3 serialization and deserialization" +
-    " works with one batch") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName,
-      SQLConf.SHUFFLE_PARTITIONS.key ->
-        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
-      withTempDir { checkpointDir =>
-        val testKeyEncoder = Encoders.STRING
-        val testValueEncoder = Encoders.scalaInt
-        val schema = List(ColumnFamilySchemaV1(
-          "countState",
-          testKeyEncoder.schema,
-          testValueEncoder.schema,
-          NoPrefixKeyStateEncoderSpec(testKeyEncoder.schema),
-          None
-        ))
-
-        val schemaFile = new StateSchemaV3File(
-          spark.sessionState.newHadoopConf(), checkpointDir.getCanonicalPath)
-        val path = schemaFile.addWithUUID(0, schema)
-
-        assert(schemaFile.getWithPath(path) == schema)
-      }
-    }
-  }
-
-  test("transformWithState - verify StateSchemaV3 serialization and deserialization" +
-    " works with multiple batches") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName,
-      SQLConf.SHUFFLE_PARTITIONS.key ->
-        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
-      withTempDir { checkpointDir =>
-        val testKeyEncoder = Encoders.STRING
-        val testValueEncoder = Encoders.scalaInt
-        val schema0 = List(ColumnFamilySchemaV1(
-          "countState",
-          testKeyEncoder.schema,
-          testValueEncoder.schema,
-          NoPrefixKeyStateEncoderSpec(testKeyEncoder.schema),
-          None
-        ))
-
-        val schema1 = List(
-          ColumnFamilySchemaV1(
-            "countState",
-            testKeyEncoder.schema,
-            testValueEncoder.schema,
-            NoPrefixKeyStateEncoderSpec(testKeyEncoder.schema),
-            None
-          ),
-          ColumnFamilySchemaV1(
-            "mostRecent",
-            testKeyEncoder.schema,
-            testValueEncoder.schema,
-            NoPrefixKeyStateEncoderSpec(testKeyEncoder.schema),
-            None
-          )
-        )
-
-        val schemaFile = new StateSchemaV3File(spark.sessionState.newHadoopConf(),
-          checkpointDir.getCanonicalPath)
-        val path0 = schemaFile.addWithUUID(0, schema0)
-
-        assert(schemaFile.getWithPath(path0) == schema0)
-
-        // test the case where we are trying to add the schema after
-        // restarting after a few batches
-        val path1 = schemaFile.addWithUUID(3, schema1)
-        val latestSchema = schemaFile.getWithPath(path1)
-
-        assert(latestSchema == schema1)
-      }
-    }
-  }
-
   test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
@@ -894,20 +818,20 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
 
       val keySchema = new StructType().add("value", StringType)
-      val schema0 = ColumnFamilySchemaV1(
+      val schema0 = StateStoreColFamilySchema(
         "countState",
         keySchema,
         new StructType().add("value", LongType, false),
-        NoPrefixKeyStateEncoderSpec(keySchema),
+        Some(NoPrefixKeyStateEncoderSpec(keySchema)),
         None
       )
-      val schema1 = ColumnFamilySchemaV1(
+      val schema1 = StateStoreColFamilySchema(
         "listState",
         keySchema,
         new StructType()
          .add("id", LongType, false)
          .add("name", StringType),
-        NoPrefixKeyStateEncoderSpec(keySchema),
+        Some(NoPrefixKeyStateEncoderSpec(keySchema)),
         None
       )
 
@@ -917,11 +841,11 @@ class TransformWithStateSuite extends StateStoreMetricsTest
 
       val compositeKeySchema = new StructType()
         .add("key", new StructType().add("value", StringType))
         .add("userKey", userKeySchema)
-      val schema2 = ColumnFamilySchemaV1(
+      val schema2 = StateStoreColFamilySchema(
         "mapState",
         compositeKeySchema,
         new StructType().add("value", StringType),
-        PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1),
+        Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)),
         Option(userKeySchema)
       )
@@ -937,9 +861,13 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         AddData(inputData, "a", "b"),
         CheckNewAnswer(("a", "1"), ("b", "1")),
         Execute { q =>
+          q.lastProgress.runId
           val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath
-          val schemaFile = new StateSchemaV3File(hadoopConf, stateSchemaPath.getName)
-          val colFamilySeq = schemaFile.deserialize(fm.open(schemaFilePath))
+          val providerId = StateStoreProviderId(StateStoreId(
+            checkpointDir.getCanonicalPath, 0, 0), q.lastProgress.runId)
+          val checker = new StateSchemaCompatibilityChecker(providerId,
+            hadoopConf, Some(schemaFilePath))
+          val colFamilySeq = checker.readSchemaFile()
 
           assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
             q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index 6e6d4de94701..db5cb027d39d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl, ValueStateImplWithTTL}
-import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaV3File, TestClass}
+import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types._
@@ -265,24 +265,24 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
       val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
 
       val keySchema = new StructType().add("value", StringType)
-      val schema0 = ColumnFamilySchemaV1(
+      val schema0 = StateStoreColFamilySchema(
         "valueStateTTL",
         keySchema,
         new StructType().add("value",
           new StructType()
             .add("value", IntegerType, false))
           .add("ttlExpirationMs", LongType),
-        NoPrefixKeyStateEncoderSpec(keySchema),
+        Some(NoPrefixKeyStateEncoderSpec(keySchema)),
         None
       )
-      val schema1 = ColumnFamilySchemaV1(
+      val schema1 = StateStoreColFamilySchema(
         "valueState",
         keySchema,
         new StructType().add("value", IntegerType, false),
-        NoPrefixKeyStateEncoderSpec(keySchema),
+        Some(NoPrefixKeyStateEncoderSpec(keySchema)),
         None
       )
-      val schema2 = ColumnFamilySchemaV1(
+      val schema2 = StateStoreColFamilySchema(
         "listState",
         keySchema,
         new StructType().add("value",
@@ -290,7 +290,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
           new StructType()
             .add("id", LongType, false)
             .add("name", StringType))
          .add("ttlExpirationMs", LongType),
-        NoPrefixKeyStateEncoderSpec(keySchema),
+        Some(NoPrefixKeyStateEncoderSpec(keySchema)),
         None
       )
 
@@ -300,14 +300,14 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
       val compositeKeySchema = new StructType()
         .add("key", new StructType().add("value", StringType))
         .add("userKey", userKeySchema)
-      val schema3 = ColumnFamilySchemaV1(
+      val schema3 = StateStoreColFamilySchema(
         "mapState",
         compositeKeySchema,
         new StructType().add("value",
          new StructType()
            .add("value", StringType))
          .add("ttlExpirationMs", LongType),
-        PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1),
+        Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)),
         Option(userKeySchema)
       )
 
@@ -333,8 +333,11 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
         CheckNewAnswer(),
         Execute { q =>
           val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath
-          val schemaFile = new StateSchemaV3File(hadoopConf, stateSchemaPath.getName)
-          val colFamilySeq = schemaFile.deserialize(fm.open(schemaFilePath))
+          val providerId = StateStoreProviderId(StateStoreId(
+            checkpointDir.getCanonicalPath, 0, 0), q.lastProgress.runId)
+          val checker = new StateSchemaCompatibilityChecker(providerId,
+            hadoopConf, Some(schemaFilePath))
+          val colFamilySeq = checker.readSchemaFile()
 
           assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
             q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt)
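
For context, the rough sketch below (not taken from the patch itself) summarizes the shape of the unified schema API that the refactored tests above exercise. It assumes the fixtures defined in StateSchemaCompatibilityCheckerSuite (dir, opId, partitionId, keySchema, valueSchema, hadoopConf) plus a stateSchemaDir resolved the way the suite's stateSchemaDirPath helper does, and uses only calls that appear in this diff; it is illustrative only.
```
// Rough sketch -- mirrors StateSchemaCompatibilityCheckerSuite in this patch.
// Assumed fixtures: dir, opId, partitionId, keySchema, valueSchema, hadoopConf, stateSchemaDir.
import java.util.UUID

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.execution.streaming.state._

val providerId = StateStoreProviderId(StateStoreId(dir, opId, partitionId), UUID.randomUUID())

// One entry per column family; with schema version 3 the key encoder spec is recorded as well.
val schemas = List(
  StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, keySchema, valueSchema,
    keyStateEncoderSpec = Some(NoPrefixKeyStateEncoderSpec(keySchema))))

// For the v3 format the schema file location is passed in explicitly,
// using the "<batchId>_<uuid>" naming seen in the tests above (batch 0 here).
val schemaFilePath = new Path(stateSchemaDir, s"0_${UUID.randomUUID()}")
val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf,
  schemaFilePath = Some(schemaFilePath))

// Write the column family schemas and read them back through the same checker.
checker.createSchemaFile(schemas, SchemaHelper.SchemaWriter.createSchemaWriter(3))
val readBack = checker.readSchemaFile()
```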