This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d2b00ae84341 [SPARK-54443][SS] Integrate PartitionKeyExtractor in 
Re-partition reader
d2b00ae84341 is described below

commit d2b00ae843416ba2fd9c70fe0da9e041e9e04658
Author: zifeif2 <[email protected]>
AuthorDate: Fri Jan 9 12:09:08 2026 -0800

    [SPARK-54443][SS] Integrate PartitionKeyExtractor in Re-partition reader
    
    ### What changes were proposed in this pull request?
    
    Integrate the PartitionKeyExtractor introduced in [this 
PR](https://github.com/apache/spark/pull/53355/files) into 
StatePartitionAllColumnFamiliesReader. Before this change, 
StatePartitionAllColumnFamiliesReader returned the entire key in the 
partition_key field, and SchemaUtil returned `keySchema` as the partitionKey 
schema. After this change, StatePartitionAllColumnFamiliesReader returns 
the actual partition key, and SchemaUtil returns the actual partitionKey schema.
    
    ### Why are the changes needed?
    
    When creating a StatePartitionAllColumnFamiliesReader, we need to pass 
along the stateFormatVersion and the operator name.
    In SchemaUtil, we create a `getExtractor` helper function. It is used 
when getSourceSchema is called (for the default column family), as well as when 
StatePartitionAllColumnFamiliesReader is initialized, because the reader uses 
the extractor to get the partitionKey for each column family in `iterator`.
    We also added partitionKey checks to the reader suite.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    See the updated StatePartitionAllColumnFamiliesReaderSuite. We have 
hard-coded functions that extract the partition key for different column 
families from normalDF; we then compare the extracted partition key against 
the partition key read from bytesDF.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes. claude-4.5-opus
    
    Closes #53459 from zifeif2/integrate-key-extraction.
    
    Authored-by: zifeif2 <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |   7 ++
 .../datasources/v2/state/StateDataSource.scala     |  97 ++++++++++++------
 .../v2/state/StatePartitionReader.scala            |  71 ++++++++++---
 .../datasources/v2/state/utils/SchemaUtil.scala    |  48 +++++----
 .../StateStoreColumnFamilySchemaUtils.scala        |  59 ++++++++++-
 .../TransformWithStateVariableUtils.scala          |   5 +
 .../transformwithstate/timers/TimerStateImpl.scala |   6 ++
 .../streaming/state/StateStoreErrors.scala         |  10 ++
 ...tatePartitionAllColumnFamiliesReaderSuite.scala | 112 ++++++++++++++-------
 9 files changed, 314 insertions(+), 101 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 1979fbffe361..4a0bfa704b2d 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5771,6 +5771,13 @@
     ],
     "sqlState" : "XXKST"
   },
+  "STATE_STORE_UNKNOWN_INTERNAL_COLUMN_FAMILY" : {
+    "message" : [
+      "Unknown internal column family: <colFamilyName>.",
+      "This internal column family is not recognized by the 
StateStoreColumnFamilySchemaUtils."
+    ],
+    "sqlState" : "XXKST"
+  },
   "STATE_STORE_UNSUPPORTED_OPERATION" : {
     "message" : [
       "<operationType> operation not supported with <entity>"
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 897bb8339197..40fba5a90cbd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -44,6 +44,7 @@ import 
org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpoint
 import 
org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, 
KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, 
PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, 
StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, 
StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, 
StateStoreProviderId}
 import 
org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors
 import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.streaming.TimeMode
 import org.apache.spark.sql.types.StructType
@@ -75,9 +76,8 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
         sourceOptions.resolvedCpLocation,
         stateConf.providerClass)
     }
-    val stateStoreReaderInfo: StateStoreReaderInfo = 
getStoreMetadataAndRunChecks(
-      sourceOptions)
 
+    val stateStoreReaderInfo = getStoreMetadataAndRunChecks(sourceOptions)
     // The key state encoder spec should be available for all operators except 
stream-stream joins
     val keyStateEncoderSpec = if 
(stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
       stateStoreReaderInfo.keyStateEncoderSpecOpt.get
@@ -91,15 +91,14 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
       stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
       stateStoreReaderInfo.stateSchemaProviderOpt,
       stateStoreReaderInfo.joinColFamilyOpt,
-      Option(stateStoreReaderInfo.allColumnFamiliesReaderInfo))
+      stateStoreReaderInfo.allColumnFamiliesReaderInfo)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, options))
 
-    val stateStoreReaderInfo: StateStoreReaderInfo = 
getStoreMetadataAndRunChecks(
-      sourceOptions)
+    val stateStoreReaderInfo = getStoreMetadataAndRunChecks(sourceOptions)
     val oldSchemaFilePaths = 
StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
 
     val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
@@ -123,7 +122,9 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
       SchemaUtil.getSourceSchema(sourceOptions, keySchema,
         valueSchema,
         stateStoreReaderInfo.transformWithStateVariableInfoOpt,
-        stateStoreReaderInfo.stateStoreColFamilySchemaOpt)
+        stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
+        stateStoreReaderInfo.allColumnFamiliesReaderInfo.map(_.operatorName),
+        
stateStoreReaderInfo.allColumnFamiliesReaderInfo.flatMap(_.stateFormatVersion))
     } catch {
       case NonFatal(e) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
@@ -132,6 +133,29 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
 
   override def supportsExternalMetadata(): Boolean = false
 
+  /**
+   * Return the state format version for SYMMETRIC_HASH_JOIN operators,
+   * otherwise None for non-join operators.
+   * This currently only supports join operators because this function is only 
used to
+   * create a PartitionKeyExtractor through PartitionKeyExtractorFactory where 
only join operators
+   * require state format version
+   */
+  private def getStateFormatVersion(
+      storeMetadata: Array[StateMetadataTableEntry],
+      checkpointLocation: String,
+      batchId: Long): Option[Int] = {
+    if (storeMetadata.nonEmpty &&
+      storeMetadata.head.operatorName == 
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME) {
+      new StreamingQueryCheckpointMetadata(session, 
checkpointLocation).offsetLog
+        .get(batchId)
+        .flatMap(_.metadataOpt)
+        .flatMap(_.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key))
+        .map(_.toInt)
+    } else {
+      None
+    }
+  }
+
   /**
    * Returns true if this is a read-all-column-families request for a 
stream-stream join
    * that uses virtual column families (state format version 3).
@@ -260,8 +284,8 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
     }
   }
 
-  private def getStoreMetadataAndRunChecks(sourceOptions: StateSourceOptions):
-    StateStoreReaderInfo = {
+  private def getStoreMetadataAndRunChecks(
+      sourceOptions: StateSourceOptions): StateStoreReaderInfo = {
     val storeMetadata = StateDataSource.getStateStoreMetadata(sourceOptions, 
hadoopConf)
     if (!sourceOptions.internalOnlyReadAllColumnFamilies) {
       // Skip runStateVarChecks when reading all column families (for 
repartitioning) because:
@@ -297,28 +321,31 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
           if (sourceOptions.readRegisteredTimers) {
             stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
           }
-          // When reading all column families (for repartitioning), we collect 
all state variable
-          // infos instead of validating a specific stateVarName. This skips 
the normal validation
-          // logic because we're not reading a specific state variable - we're 
reading all of them.
           if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+            // When reading all column families (for repartitioning) for TWS 
operator,
+            // we will just choose the first state as placeholder for default 
column family,
+            // because we need to use matching stateVariableInfo and 
stateStoreColFamilySchemaOpt
+            // to inferSchema (partitionKey in particular) later
+            stateVarName = operatorProperties.stateVariables.head.stateName
             stateVariableInfos = operatorProperties.stateVariables
-          } else {
-            var stateVarInfoList = operatorProperties.stateVariables
-              .filter(stateVar => stateVar.stateName == stateVarName)
-            if (stateVarInfoList.isEmpty &&
+          }
+
+          var stateVarInfoList = operatorProperties.stateVariables
+            .filter(stateVar => stateVar.stateName == stateVarName)
+          if (!TimerStateUtils.isTimerCFName(stateVarName) &&
               
StateStoreColumnFamilySchemaUtils.isTestingInternalColFamily(stateVarName)) {
-              // pass this dummy TWSStateVariableInfo for TWS internal column 
family during testing,
-              // because internalColumns are not register in 
operatorProperties.stateVariables,
-              // thus stateVarInfoList will be empty.
-              stateVarInfoList = List(TransformWithStateVariableInfo(
-                stateVarName, StateVariableType.ValueState, false
-              ))
-            }
-            require(stateVarInfoList.size == 1, s"Failed to find unique state 
variable info " +
-              s"for state variable $stateVarName in operator 
${sourceOptions.operatorId}")
-            val stateVarInfo = stateVarInfoList.head
-            transformWithStateVariableInfoOpt = Some(stateVarInfo)
+            // pass this dummy TWSStateVariableInfo for TWS internal column 
family during testing,
+            // because internal column families are not registered in
+            // operatorProperties.stateVariables, thus stateVarInfoList will 
be empty.
+            stateVarInfoList = List(TransformWithStateVariableInfo(
+              stateVarName, StateVariableType.ValueState, false
+            ))
           }
+          require(stateVarInfoList.size == 1, s"Failed to find unique state 
variable info " +
+            s"for state variable $stateVarName in operator 
${sourceOptions.operatorId}")
+          val stateVarInfo = stateVarInfoList.head
+          transformWithStateVariableInfoOpt = Some(stateVarInfo)
+
           val schemaFilePaths = storeMetadataEntry.stateSchemaFilePaths
           val stateSchemaMetadata = 
StateSchemaMetadata.createStateSchemaMetadata(
             sourceOptions.stateCheckpointLocation.toString,
@@ -374,13 +401,27 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
       }
     }
 
+    val allColFamilyReaderInfoOpt: Option[AllColumnFamiliesReaderInfo] =
+      if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+        assert(storeMetadata.nonEmpty, "storeMetadata shouldn't be empty")
+        val operatorName = storeMetadata.head.operatorName
+        val stateFormatVersion = getStateFormatVersion(
+          storeMetadata,
+          sourceOptions.resolvedCpLocation,
+          sourceOptions.batchId
+        )
+        Some(AllColumnFamiliesReaderInfo(
+          stateStoreColFamilySchemas, stateVariableInfos, operatorName, 
stateFormatVersion))
+      } else {
+        None
+      }
     StateStoreReaderInfo(
       keyStateEncoderSpecOpt,
       stateStoreColFamilySchemaOpt,
       transformWithStateVariableInfoOpt,
       stateSchemaProvider,
       joinColFamilyOpt,
-      AllColumnFamiliesReaderInfo(stateStoreColFamilySchemas, 
stateVariableInfos)
+      allColFamilyReaderInfoOpt
     )
   }
 
@@ -769,7 +810,7 @@ case class StateStoreReaderInfo(
     stateSchemaProviderOpt: Option[StateSchemaProvider],
     joinColFamilyOpt: Option[String], // Only used for join op with state 
format v3
     // List of all column family schemas - used when 
internalOnlyReadAllColumnFamilies=true
-    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo
+    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]
 )
 
 object StateDataSource {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 40c6999f962f..f55822af4ada 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, 
PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils,
 StatePartitionKeyExtractorFactory}
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils,
 StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.state._
@@ -38,7 +39,9 @@ import org.apache.spark.util.{NextIterator, 
SerializableConfiguration}
  */
 case class AllColumnFamiliesReaderInfo(
     colFamilySchemas: Set[StateStoreColFamilySchema] = Set.empty,
-    stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty)
+    stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty,
+    operatorName: String,
+    stateFormatVersion: Option[Int] = None)
 
 /**
  * An implementation of [[PartitionReaderFactory]] for State data source. This 
is used to support
@@ -285,6 +288,38 @@ class StatePartitionAllColumnFamiliesReader(
 
   private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+  private val operatorName = allColumnFamiliesReaderInfo.operatorName
+  private val stateFormatVersion = 
allColumnFamiliesReaderInfo.stateFormatVersion
+
+  private def isTWSOperator(operatorName: String): Boolean = {
+    StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName)
+  }
+
+  private def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: 
String): Boolean = {
+    isTWSOperator(operatorName) && colFamilyName == 
StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  // Create extractors for each column family - each column family may have 
different key schema
+  private lazy val cfPartitionKeyExtractors: Map[String, 
StatePartitionKeyExtractor] = {
+
+    stateStoreColFamilySchemas
+      // Filter out default column family for TWS operators because they are 
not in use
+      // and will not have a key extractor
+      .filter(schema => !isDefaultColFamilyInTWS(operatorName, 
schema.colFamilyName))
+      .map { cfSchema =>
+        val stateVariableInfoOpt = stateVariableInfos.find(
+          _.stateName == StateStoreColumnFamilySchemaUtils.getBaseStateName(
+            cfSchema.colFamilyName))
+        val extractor = StatePartitionKeyExtractorFactory.create(
+          operatorName,
+          cfSchema.keySchema,
+          partition.sourceOptions.storeName,
+          cfSchema.colFamilyName,
+          stateFormatVersion,
+          stateVariableInfoOpt)
+        cfSchema.colFamilyName -> extractor
+      }.toMap
+  }
 
   private def isListType(colFamilyName: String): Boolean = {
     SchemaUtil.checkVariableType(
@@ -364,21 +399,27 @@ class StatePartitionAllColumnFamiliesReader(
 
   override lazy val iter: Iterator[InternalRow] = {
     // Iterate all column families and concatenate results
-    stateStoreColFamilySchemas.iterator.flatMap { cfSchema =>
-      if (isListType(cfSchema.colFamilyName)) {
-        store.iterator(cfSchema.colFamilyName).flatMap(
-          pair =>
-            store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
-              value =>
-                SchemaUtil.unifyStateRowPairAsRawBytes((pair.key, value), 
cfSchema.colFamilyName)
-            }
-        )
-      } else {
-        store.iterator(cfSchema.colFamilyName).map { pair =>
-          SchemaUtil.unifyStateRowPairAsRawBytes(
-            (pair.key, pair.value), cfSchema.colFamilyName)
+    stateStoreColFamilySchemas.iterator
+      // Filter out default column family for TWS operators because they are 
not in use
+      // and will not have data
+      .filter(schema => !isDefaultColFamilyInTWS(operatorName, 
schema.colFamilyName))
+      .flatMap { cfSchema =>
+        val extractor = cfPartitionKeyExtractors(cfSchema.colFamilyName)
+        if (isListType(cfSchema.colFamilyName)) {
+          store.iterator(cfSchema.colFamilyName).flatMap(
+            pair =>
+              store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
+                value =>
+                  SchemaUtil.unifyStateRowPairAsRawBytes(
+                    (pair.key, value), cfSchema.colFamilyName, extractor)
+              }
+          )
+        } else {
+          store.iterator(cfSchema.colFamilyName).map { pair =>
+            SchemaUtil.unifyStateRowPairAsRawBytes(
+              (pair.key, pair.value), cfSchema.colFamilyName, extractor)
+          }
         }
-      }
     }
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index 44e032f5163a..1210f0605179 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -25,9 +25,10 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeRow}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import 
org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, 
StateSourceOptions}
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.StatePartitionKeyExtractorFactory
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType,
 TransformWithStateVariableInfo}
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType._
-import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, 
StateStoreColFamilySchema, UnsafeRowPair}
+import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, 
StatePartitionKeyExtractor, StateStoreColFamilySchema, UnsafeRowPair}
 import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, 
IntegerType, LongType, MapType, StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
@@ -49,8 +50,26 @@ object SchemaUtil {
       keySchema: StructType,
       valueSchema: StructType,
       transformWithStateVariableInfoOpt: 
Option[TransformWithStateVariableInfo],
-      stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]): 
StructType = {
-    if (transformWithStateVariableInfoOpt.isDefined) {
+      stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+      operatorName: Option[String],
+      stateFormatVersion: Option[Int] = None): StructType = {
+    if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+      require(stateStoreColFamilySchemaOpt.isDefined)
+      require(operatorName.isDefined)
+      val colFamilyName: String = 
stateStoreColFamilySchemaOpt.get.colFamilyName
+      val extractor = StatePartitionKeyExtractorFactory.create(
+        operatorName.get,
+        keySchema,
+        sourceOptions.storeName,
+        colFamilyName,
+        stateFormatVersion,
+        transformWithStateVariableInfoOpt)
+      new StructType()
+        .add("partition_key", extractor.partitionKeySchema)
+        .add("key_bytes", BinaryType)
+        .add("value_bytes", BinaryType)
+        .add("column_family_name", StringType)
+    } else if (transformWithStateVariableInfoOpt.isDefined) {
       require(stateStoreColFamilySchemaOpt.isDefined)
       generateSchemaForStateVar(transformWithStateVariableInfoOpt.get,
         stateStoreColFamilySchemaOpt.get, sourceOptions)
@@ -61,14 +80,6 @@ object SchemaUtil {
         .add("key", keySchema)
         .add("value", valueSchema)
         .add("partition_id", IntegerType)
-    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
-      new StructType()
-        // TODO [SPARK-54443]: change keySchema to a more specific type after 
we
-        // can extract partition key from keySchema
-        .add("partition_key", keySchema)
-        .add("key_bytes", BinaryType)
-        .add("value_bytes", BinaryType)
-        .add("column_family_name", StringType)
     } else {
       new StructType()
         .add("key", keySchema)
@@ -87,18 +98,17 @@ object SchemaUtil {
 
   /**
    * Returns an InternalRow representing
-   * 1. partitionKey
+   * 1. partitionKey (extracted using the StatePartitionKeyExtractor)
    * 2. key in bytes
    * 3. value in bytes
    * 4. column family name
    */
   def unifyStateRowPairAsRawBytes(
       pair: (UnsafeRow, UnsafeRow),
-      colFamilyName: String): InternalRow = {
+      colFamilyName: String,
+      extractor: StatePartitionKeyExtractor): InternalRow = {
     val row = new GenericInternalRow(4)
-    // todo [SPARK-54443]: change keySchema to more specific type after we
-    //  can extract partition key from keySchema
-    row.update(0, pair._1)
+    row.update(0, extractor.partitionKey(pair._1))
     row.update(1, pair._1.getBytes)
     row.update(2, pair._2.getBytes)
     row.update(3, UTF8String.fromString(colFamilyName))
@@ -266,7 +276,9 @@ object SchemaUtil {
       "value_bytes" -> classOf[BinaryType],
       "column_family_name" -> classOf[StringType])
 
-    val expectedFieldNames = if (transformWithStateVariableInfoOpt.isDefined) {
+    val expectedFieldNames = if 
(sourceOptions.internalOnlyReadAllColumnFamilies) {
+      Seq("partition_key", "key_bytes", "value_bytes", "column_family_name")
+    } else if (transformWithStateVariableInfoOpt.isDefined) {
       val stateVarInfo = transformWithStateVariableInfoOpt.get
       val stateVarType = stateVarInfo.stateVariableType
 
@@ -305,8 +317,6 @@ object SchemaUtil {
       }
     } else if (sourceOptions.readChangeFeed) {
       Seq("batch_id", "change_type", "key", "value", "partition_id")
-    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
-      Seq("partition_key", "key_bytes", "value_bytes", "column_family_name")
     } else {
       Seq("key", "value", "partition_id")
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala
index bbcdb1ff5326..4271cafb2eb7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala
@@ -21,8 +21,9 @@ import scala.collection.mutable
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.TransformWithStateKeyValueRowSchemaUtils._
-import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.TransformWithStateVariableUtils.getRowCounterCFName
-import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, 
StateStoreColFamilySchema}
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.TransformWithStateVariableUtils.{getRowCounterCFName,
 getStateNameFromRowCounterCFName, isRowCounterCFName}
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
+import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, 
StateStoreColFamilySchema, StateStoreErrors}
 import org.apache.spark.sql.types._
 
 object StateStoreColumnFamilySchemaUtils {
@@ -114,6 +115,60 @@ object StateStoreColumnFamilySchemaUtils {
     org.apache.spark.util.Utils.isTesting && isInternalColFamily(colFamilyName)
   }
 
+  /**
+   * Extracts the base state variable name from internal column family names.
+   * Internal column families are auxiliary data structures (TTL index, min 
expiry index,
+   * count index, row counter) that are associated with a user-defined state 
variable.
+   *
+   * @param colFamilyName The internal column family name (must start with "$")
+   * @return The base state variable name
+   * @throws IllegalArgumentException if the column family name is not a 
recognized internal type
+   */
+  private def getStateNameForInternalCF(colFamilyName: String): String = {
+    if (isTtlColFamilyName(colFamilyName)) {
+      getStateNameFromTtlColFamily(colFamilyName)
+    } else if (isMinExpiryIndexCFName(colFamilyName)) {
+      getStateNameFromMinExpiryIndexCFName(colFamilyName)
+    } else if (isCountIndexCFName(colFamilyName)) {
+      getStateNameFromCountIndexCFName(colFamilyName)
+    } else if (isRowCounterCFName(colFamilyName)) {
+      getStateNameFromRowCounterCFName(colFamilyName)
+    } else if (TimerStateUtils.isTimerCFName(colFamilyName)) {
+      // Return the primary index for timer secondary index column family
+      // because we only store the primary index column family in the
+      // StateMetadataTableEntry.operatorProperties.stateVariables
+      if (TimerStateUtils.isTimerSecondaryIndexCF(colFamilyName)) {
+        TimerStateUtils.getPrimaryIndexFromSecondaryIndexCF(colFamilyName)
+      } else {
+        colFamilyName
+      }
+    } else {
+      throw StateStoreErrors.unknownInternalColumnFamily(colFamilyName)
+    }
+  }
+
+  /**
+   * Extracts the base state variable name from a column family name.
+   *
+   * This is useful for looking up the stateVariableInfo associated with a 
column family,
+   * since stateVariableInfo is only stored for primary/user-facing column 
families.
+   *
+   * Returns:
+   *   - For internal CFs (e.g., $ttl_*, $min_*, $count_*): the associated 
state variable name
+   *   - For timer secondary index CFs: the corresponding primary index timer 
CF name
+   *   - For all other CFs (regular state variables, timer primary index): the 
name as-is
+   *
+   * @param colFamilyName The column family name
+   * @return The base state variable name for stateVariableInfo lookup
+   */
+  def getBaseStateName(colFamilyName: String): String = {
+    if (isInternalColFamily(colFamilyName)) {
+      getStateNameForInternalCF(colFamilyName)
+    } else {
+      colFamilyName
+    }
+  }
+
   def getValueStateSchema[T](
       stateName: String,
       keyEncoder: ExpressionEncoder[Any],
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala
index 5b4100e9c256..64d73ce5f545 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala
@@ -63,6 +63,11 @@ object TransformWithStateVariableUtils {
   def isRowCounterCFName(colFamilyName: String): Boolean = {
     colFamilyName.startsWith(ROW_COUNTER_CF_PREFIX)
   }
+
+  def getStateNameFromRowCounterCFName(colFamilyName: String): String = {
+    require(isRowCounterCFName(colFamilyName))
+    colFamilyName.substring(ROW_COUNTER_CF_PREFIX.length)
+  }
 }
 
 // Enum of possible State Variable types
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala
index 75624c9af9ef..101265fd8d83 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala
@@ -64,6 +64,12 @@ object TimerStateUtils {
     assert(isTimerCFName(colFamilyName), s"Column family name must be for a 
timer: $colFamilyName")
     colFamilyName.endsWith(TIMESTAMP_TO_KEY_CF)
   }
+
+  def getPrimaryIndexFromSecondaryIndexCF(colFamilyName: String): String = {
+    assert(isTimerSecondaryIndexCF(colFamilyName),
+      s"Column family name must be for a timer secondary index: 
$colFamilyName")
+    colFamilyName.replace(TIMESTAMP_TO_KEY_CF, KEY_TO_TIMESTAMP_CF)
+  }
 }
 
 /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index cd29f8f30f6e..fc9c256735b2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -291,6 +291,11 @@ object StateStoreErrors {
   def stateStoreCheckpointIdsNotSupported(msg: String): 
StateStoreCheckpointIdsNotSupported = {
     new StateStoreCheckpointIdsNotSupported(msg)
   }
+
+  def unknownInternalColumnFamily(colFamilyName: String):
+      StateStoreUnknownInternalColumnFamily = {
+    new StateStoreUnknownInternalColumnFamily(colFamilyName)
+  }
 }
 
 trait ConvertableToCannotLoadStoreError {
@@ -637,3 +642,8 @@ class StateStoreAutoSnapshotRepairFailed(
       "selectedSnapshots" -> selectedSnapshots.mkString(","),
       "eligibleSnapshots" -> eligibleSnapshots.mkString(",")),
     cause)
+
+class StateStoreUnknownInternalColumnFamily(colFamilyName: String)
+  extends SparkRuntimeException(
+    errorClass = "STATE_STORE_UNKNOWN_INTERNAL_COLUMN_FAMILY",
+    messageParameters = Map("colFamilyName" -> colFamilyName))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
index cb864da48731..4eb571804935 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
@@ -88,13 +88,22 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
   /**
    * Compares normal read data with bytes read data for a specific column 
family.
    * Converts normal rows to bytes then compares with bytes read.
+   *
+   * @param normalDf Normal read data with columns (partition_id, key, value)
+   * @param bytesDf Bytes read data with columns (partition_key, key_bytes, 
value_bytes, cf_name)
+   * @param columnFamily The column family name to filter on
+   * @param keySchema Schema of the full key
+   * @param valueSchema Schema of the value
+   * @param partitionKeyExtractor Function to extract partition key from full 
key Row.
+   *                              If None, assumes partition key equals the 
full key.
    */
   private def compareNormalAndBytesData(
       normalDf: Array[Row],
       bytesDf: Array[Row],
       columnFamily: String,
       keySchema: StructType,
-      valueSchema: StructType): Unit = {
+      valueSchema: StructType,
+      partitionKeyExtractor: Option[Row => Row] = None): Unit = {
 
     // Filter bytes data for the specified column family and extract raw bytes 
directly
     val filteredBytesData = bytesDf.filter { row =>
@@ -114,12 +123,17 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
     val keyConverter = 
CatalystTypeConverters.createToCatalystConverter(keySchema)
     val valueConverter = 
CatalystTypeConverters.createToCatalystConverter(valueSchema)
 
-    // Convert normal data to bytes
-    val normalAsBytes = normalDf.toSeq.map { row =>
+    // Convert normal data to (keyBytes, valueBytes, partitionKey)
+    val normalData = normalDf.toSeq.map { row =>
       val key = row.getStruct(1)
       val value = if (row.isNullAt(2)) null else row.getStruct(2)
 
-      // Convert key to InternalRow, then to UnsafeRow, then get bytes
+      // Extract partition key - use extractor if provided, otherwise use full 
key
+      val partitionKey: Row = partitionKeyExtractor match {
+        case Some(extractor) => extractor(key)
+        case None => key
+      }
+      // Convert key to bytes
       val keyInternalRow = keyConverter(key).asInstanceOf[InternalRow]
       val keyUnsafeRow = keyProjection(keyInternalRow)
       // IMPORTANT: Must clone the bytes array since getBytes() returns a 
reference
@@ -137,26 +151,29 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
         valueUnsafeRow.getBytes.clone()
       }
 
-      (keyBytes, valueBytes)
+      (keyBytes, valueBytes, partitionKey)
     }
 
-    // Extract raw bytes from bytes read data (no 
deserialization/reserialization)
-    val bytesAsBytes = filteredBytesData.map { row =>
+    // Extract (keyBytes, valueBytes, partitionKey) from bytes read data
+    val bytesData = filteredBytesData.map { row =>
+      val partitionKey = row.getStruct(0)
       val keyBytes = row.getAs[Array[Byte]](1)
       val valueBytes = row.getAs[Array[Byte]](2)
-      (keyBytes, valueBytes)
+      (keyBytes, valueBytes, partitionKey)
     }
 
-    // Sort both for comparison (since Set equality doesn't work well with 
byte arrays)
-    val normalSorted = normalAsBytes.sortBy(x => (x._1.mkString(","), 
x._2.mkString(",")))
-    val bytesSorted = bytesAsBytes.sortBy(x => (x._1.mkString(","), 
x._2.mkString(",")))
+    // Sort both for comparison by key and value bytes
+    val normalSorted = normalData.sortBy(x => (x._1.mkString(","), 
x._2.mkString(",")))
+    val bytesSorted = bytesData.sortBy(x => (x._1.mkString(","), 
x._2.mkString(",")))
 
     assert(normalSorted.length == bytesSorted.length,
       s"Size mismatch: normal has ${normalSorted.length}, bytes has 
${bytesSorted.length}")
 
-    // Compare each pair
+    // Compare each tuple (keyBytes, valueBytes, partitionKey)
     normalSorted.zip(bytesSorted).zipWithIndex.foreach {
-      case (((normalKey, normalValue), (bytesKey, bytesValue)), idx) =>
+      case (((normalKey, normalValue, normalPartitionKey),
+             (bytesKey, bytesValue, bytesPartitionKey)), idx) =>
+        assert(normalPartitionKey == bytesPartitionKey)
         assert(Arrays.equals(normalKey, bytesKey),
           s"Key mismatch at index $idx:\n" +
             s"  Normal: ${normalKey.mkString("[", ",", "]")}\n" +
@@ -179,7 +196,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       keySchema: StructType,
       valueSchema: StructType,
       extraOptions: Map[String, String] = Map.empty,
-      selectExprs: Seq[String] = Seq("partition_id", "key", "value")): Unit = {
+      selectExprs: Seq[String] = Seq("partition_id", "key", "value"),
+      partitionKeyExtractor: Option[Row => Row] = None): Unit = {
     var reader = spark.read
       .format("statestore")
       .option(StateSourceOptions.PATH, checkpointDir)
@@ -194,7 +212,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       allBytesData,
       stateVarName,
       keySchema,
-      valueSchema)
+      valueSchema,
+      partitionKeyExtractor)
   }
 
   /**
@@ -252,7 +271,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       allBytesData,
       s"$$${timerPrefix}Timers_keyToTimestamp",
       keyToTimestampSchema,
-      dummySchema)
+      dummySchema,
+      partitionKeyExtractor = Some(compositeKey => compositeKey.getStruct(0)))
 
     val timestampToKeyNormalDf = timerBaseDf.selectExpr(
       
TimerTestUtils.getTimerSelectExpressions(s"$$${timerPrefix}Timers_timestampToKey"):
 _*)
@@ -261,7 +281,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       allBytesData,
       s"$$${timerPrefix}Timers_timestampToKey",
       timestampToKeySchema,
-      dummySchema)
+      dummySchema,
+      partitionKeyExtractor = Some(compositeKey => compositeKey.getStruct(1)))
   }
 
   /**
@@ -278,7 +299,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       tempDir: String,
       keySchema: StructType,
       valueSchema: StructType,
-      storeName: Option[String] = None): Unit = {
+      storeName: Option[String] = None,
+      partitionKeyExtractor: Option[Row => Row] = None): Unit = {
     val normalDf = getNormalReadDf(tempDir, storeName)
     val bytesDf = getBytesReadDf(tempDir, storeName)
 
@@ -288,7 +310,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       bytesDf.collect(),
       StateStore.DEFAULT_COL_FAMILY_NAME,
       keySchema,
-      valueSchema)
+      valueSchema,
+      partitionKeyExtractor)
   }
 
   /**
@@ -307,7 +330,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       colFamilyName: String,
       sharedBytesDf: DataFrame,
       keySchema: StructType,
-      valueSchema: StructType): Unit = {
+      valueSchema: StructType,
+      partitionKeyExtractor: Option[Row => Row] = None): Unit = {
     val normalDf = getNormalReadDf(tempDir, Option(colFamilyName))
 
     compareNormalAndBytesData(
@@ -315,7 +339,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       sharedBytesDf.collect(),
       colFamilyName,
       keySchema,
-      valueSchema)
+      valueSchema,
+      partitionKeyExtractor)
   }
 
   // Run all tests with both changelog checkpointing enabled and disabled
@@ -399,7 +424,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
 
         val (keySchema, valueSchema) = SessionWindowTestUtils.getSchemas()
 
-        validateStateStore(tempDir.getAbsolutePath, keySchema, valueSchema)
+        validateStateStore(tempDir.getAbsolutePath, keySchema, valueSchema,
+          partitionKeyExtractor = Some(row => Row(row.getString(0))))
       }
     }
 
@@ -414,13 +440,12 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
             val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
             val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-            validateBytesReadDfSchema(bytesDf)
-            compareNormalAndBytesData(
-              normalData, bytesDf.collect(), "default", keySchema, valueSchema)
-          }
+          validateBytesReadDfSchema(bytesDf)
+          compareNormalAndBytesData(
+            normalData, bytesDf.collect(), "default", keySchema, valueSchema)
         }
       }
-    )
+    })
 
     Seq(1, 2).foreach(version =>
       testWithChangelogConfig(s"stream-stream join, state ver $version") {
@@ -448,7 +473,9 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
                 tempDir.getAbsolutePath,
                 keyWithIndexKeySchema,
                 keyWithIndexValueSchema,
-                Some(storeName))
+                Some(storeName),
+                partitionKeyExtractor = Some(compositeKey =>
+                  Row(compositeKey.getInt(0))))
             }
           }
         }
@@ -514,7 +541,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
           tempDir.getAbsolutePath, allBytesData,
           MultiStateVarProcessorTestUtils.ITEMS_MAP, mapKeySchema, 
mapValueSchema,
           selectExprs = MultiStateVarProcessorTestUtils.getSelectExpressions(
-            MultiStateVarProcessorTestUtils.ITEMS_MAP))
+            MultiStateVarProcessorTestUtils.ITEMS_MAP),
+          partitionKeyExtractor = Some(compositeKey => 
compositeKey.getStruct(0)))
       }
     }
 
@@ -616,13 +644,18 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
         val (ttlIndexKeySchema, ttlValueSchema) = 
schemas(TTLProcessorUtils.LIST_STATE_TTL_INDEX)
         val (_, minExpiryValueSchema) = 
schemas(TTLProcessorUtils.LIST_STATE_MIN)
         val (_, countValueSchema) = schemas(TTLProcessorUtils.LIST_STATE_COUNT)
+        val ttlColFamilyPartitionKeyExtractor: Option[Row => Row] =
+          Some(compositeKey => compositeKey.getStruct(1))
         val columnFamilyAndKeyValueSchema = Seq(
-          (TTLProcessorUtils.LIST_STATE_TTL_INDEX, ttlIndexKeySchema, 
ttlValueSchema),
-          (TTLProcessorUtils.LIST_STATE_MIN, groupByKeySchema, 
minExpiryValueSchema),
-          (TTLProcessorUtils.LIST_STATE_COUNT, groupByKeySchema, 
countValueSchema)
+          (TTLProcessorUtils.LIST_STATE_TTL_INDEX,
+            ttlIndexKeySchema, ttlValueSchema,
+            ttlColFamilyPartitionKeyExtractor),
+          (TTLProcessorUtils.LIST_STATE_MIN, groupByKeySchema, 
minExpiryValueSchema, None),
+          (TTLProcessorUtils.LIST_STATE_COUNT, groupByKeySchema, 
countValueSchema, None)
         )
         columnFamilyAndKeyValueSchema.foreach(pair => {
-          validateColumnFamily(tempDir.getAbsolutePath, pair._1, bytesDf, 
pair._2, pair._3)
+          validateColumnFamily(
+            tempDir.getAbsolutePath, pair._1, bytesDf, pair._2, pair._3, 
pair._4)
         })
       }
     }
@@ -665,12 +698,14 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
         readAndValidateStateVar(
           tempDir.getAbsolutePath, allBytesData,
           stateVarName = TTLProcessorUtils.MAP_STATE, compositeKeySchema, 
mapStateValueSchema,
-          selectExprs = 
TTLProcessorUtils.getTTLSelectExpressions(TTLProcessorUtils.MAP_STATE))
+          selectExprs = 
TTLProcessorUtils.getTTLSelectExpressions(TTLProcessorUtils.MAP_STATE),
+          partitionKeyExtractor = Some(compositeKey => 
compositeKey.getStruct(0)))
 
         // Validate $ttl_mapState column family
         readAndValidateStateVar(
           tempDir.getAbsolutePath, allBytesData,
-          stateVarName = TTLProcessorUtils.MAP_STATE_TTL_INDEX, 
ttlIndexKeySchema, dummyValueSchema)
+          stateVarName = TTLProcessorUtils.MAP_STATE_TTL_INDEX, 
ttlIndexKeySchema, dummyValueSchema,
+          partitionKeyExtractor = Some(ttlKey => 
ttlKey.getStruct(1).getStruct(0)))
       }
     }
 
@@ -728,7 +763,8 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
           allBytesData,
           TTLProcessorUtils.VALUE_STATE_TTL_INDEX,
           ttlIndexKeySchema,
-          dummyValueSchema)
+          dummyValueSchema,
+          partitionKeyExtractor = Some(ttlKey => ttlKey.getStruct(1)))
       }
     }
 
@@ -762,7 +798,9 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
               colFamilyName,
               stateBytesDf,
               keyWithIndexKeySchema,
-              keyWithIndexValueSchema)
+              keyWithIndexValueSchema,
+              partitionKeyExtractor = Some(compositeKey =>
+                Row(compositeKey.getInt(0))))
           }
         }
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to