This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 63b97c6ad82a [SPARK-46979][SS] Add support for specifying key and 
value encoder separately and also for each col family in RocksDB state store 
provider
63b97c6ad82a is described below

commit 63b97c6ad82afac71afcd64117346b6e0bda72bb
Author: Anish Shrigondekar <anish.shrigonde...@databricks.com>
AuthorDate: Wed Feb 14 06:19:48 2024 +0900

    [SPARK-46979][SS] Add support for specifying key and value encoder 
separately and also for each col family in RocksDB state store provider
    
    ### What changes were proposed in this pull request?
    Add support for specifying key and value encoder separately and also for 
each col family in RocksDB state store provider
    
    ### Why are the changes needed?
    This change allows us to specify encoders for keys and values separately and 
avoid encoding additional bytes. Also, it allows us to set schemas/encoders for 
individual column families, which will be required for future changes related 
to the transformWithState operator (listState/timer changes, etc.).
    
    We are refactoring a bit here given the upcoming changes, so we are 
proposing to split key and value encoders.
    Key encoders can be of 2 types:
    - with prefix scan
    - without prefix scan
    
    Value encoders can also eventually be of 2 types:
    - single value
    - multiple values (used for list state)
    
    And we now also allow setting schema and getting encoder for each column 
family.
    So after the change, we can potentially allow something like this:
    - col family 1 - with keySchema with prefix scan and valueSchema with 
single value and binary type
    - col family 2 - with keySchema without prefix scan and valueSchema with 
multiple values
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests
    ```
    [info] Run completed in 3 minutes, 5 seconds.
    [info] Total number of tests run: 286
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 286, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45038 from anishshri-db/task/SPARK-46979.
    
    Authored-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../streaming/StatefulProcessorHandleImpl.scala    |   1 -
 .../sql/execution/streaming/ValueStateImpl.scala   |  20 ++--
 .../state/HDFSBackedStateStoreProvider.scala       |   6 +-
 .../streaming/state/RocksDBStateEncoder.scala      | 106 +++++++++++----------
 .../state/RocksDBStateStoreProvider.scala          |  58 ++++++++---
 .../sql/execution/streaming/state/StateStore.scala |   6 +-
 .../streaming/state/MemoryStateStore.scala         |   7 +-
 .../streaming/state/RocksDBStateStoreSuite.scala   |  24 +++++
 8 files changed, 151 insertions(+), 77 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index fed18fc7e458..62c97d11c926 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -115,7 +115,6 @@ class StatefulProcessorHandleImpl(
   override def getValueState[T](stateName: String): ValueState[T] = {
     verify(currState == CREATED, s"Cannot create state variable with 
name=$stateName after " +
       "initialization is complete")
-    store.createColFamilyIfAbsent(stateName)
     val resultState = new ValueStateImpl[T](store, stateName, keyEncoder)
     resultState
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
index 11ae7f65b43d..c1d807144df6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.types._
  * @param store - reference to the StateStore instance to be used for storing 
state
  * @param stateName - name of logical state partition
  * @param keyEnc - Spark SQL encoder for key
- * @tparam K - data type of key
  * @tparam S - data type of object that will be stored
  */
 class ValueStateImpl[S](
@@ -40,6 +39,16 @@ class ValueStateImpl[S](
     stateName: String,
     keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging {
 
+  private val schemaForKeyRow: StructType = new StructType().add("key", 
BinaryType)
+  private val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
+  private val keySerializer = keyExprEnc.createSerializer()
+
+  private val schemaForValueRow: StructType = new StructType().add("value", 
BinaryType)
+  private val valueEncoder = UnsafeProjection.create(schemaForValueRow)
+
+  store.createColFamilyIfAbsent(stateName, schemaForKeyRow, numColsPrefixKey = 
0,
+    schemaForValueRow)
+
   // TODO: validate places that are trying to encode the key and check if we 
can eliminate/
   // add caching for some of these calls.
   private def encodeKey(): UnsafeRow = {
@@ -48,20 +57,13 @@ class ValueStateImpl[S](
       throw StateStoreErrors.implicitKeyNotFound(stateName)
     }
 
-    val toRow = keyExprEnc.createSerializer()
-    val keyByteArr = toRow
-      .apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
-
-    val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
-    val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
+    val keyByteArr = 
keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
     val keyRow = keyEncoder(InternalRow(keyByteArr))
     keyRow
   }
 
   private def encodeValue(value: S): UnsafeRow = {
-    val schemaForValueRow: StructType = new StructType().add("value", 
BinaryType)
     val valueByteArr = 
SerializationUtils.serialize(value.asInstanceOf[Serializable])
-    val valueEncoder = UnsafeProjection.create(schemaForValueRow)
     val valueRow = valueEncoder(InternalRow(valueByteArr))
     valueRow
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index dd04053c5471..b23c83f625d6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -114,7 +114,11 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
 
     override def id: StateStoreId = 
HDFSBackedStateStoreProvider.this.stateStoreId
 
-    override def createColFamilyIfAbsent(colFamilyName: String): Unit = {
+    override def createColFamilyIfAbsent(
+        colFamilyName: String,
+        keySchema: StructType,
+        numColsPrefixKey: Int,
+        valueSchema: StructType): Unit = {
       throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3193")
     }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 81755e52968b..be1bb4689507 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -22,31 +22,35 @@ import 
org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.unsafe.Platform
 
-sealed trait RocksDBStateEncoder {
+sealed trait RocksDBKeyStateEncoder {
   def supportPrefixKeyScan: Boolean
   def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
   def extractPrefixKey(key: UnsafeRow): UnsafeRow
 
   def encodeKey(row: UnsafeRow): Array[Byte]
-  def encodeValue(row: UnsafeRow): Array[Byte]
-
   def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+}
+
+sealed trait RocksDBValueStateEncoder {
+  def encodeValue(row: UnsafeRow): Array[Byte]
   def decodeValue(valueBytes: Array[Byte]): UnsafeRow
-  def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
 }
 
 object RocksDBStateEncoder {
-  def getEncoder(
+  def getKeyEncoder(
       keySchema: StructType,
-      valueSchema: StructType,
-      numColsPrefixKey: Int): RocksDBStateEncoder = {
+      numColsPrefixKey: Int): RocksDBKeyStateEncoder = {
     if (numColsPrefixKey > 0) {
-      new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+      new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey)
     } else {
-      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+      new NoPrefixKeyStateEncoder(keySchema)
     }
   }
 
+  def getValueEncoder(valueSchema: StructType): RocksDBValueStateEncoder = {
+    new SingleValueStateEncoder(valueSchema)
+  }
+
   /**
    * Encode the UnsafeRow of N bytes as a N+1 byte array.
    * @note This creates a new byte array and memcopies the UnsafeRow to the 
new array.
@@ -86,10 +90,15 @@ object RocksDBStateEncoder {
   }
 }
 
+/**
+ * RocksDB Key Encoder for UnsafeRow that supports prefix scan
+ *
+ * @param keySchema - schema of the key to be encoded
+ * @param numColsPrefixKey - number of columns to be used for prefix key
+ */
 class PrefixKeyScanStateEncoder(
     keySchema: StructType,
-    valueSchema: StructType,
-    numColsPrefixKey: Int) extends RocksDBStateEncoder {
+    numColsPrefixKey: Int) extends RocksDBKeyStateEncoder {
 
   import RocksDBStateEncoder._
 
@@ -120,8 +129,6 @@ class PrefixKeyScanStateEncoder(
 
   // Reusable objects
   private val joinedRowOnKey = new JoinedRow()
-  private val valueRow = new UnsafeRow(valueSchema.size)
-  private val rowTuple = new UnsafeRowPair()
 
   override def encodeKey(row: UnsafeRow): Array[Byte] = {
     val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
@@ -140,8 +147,6 @@ class PrefixKeyScanStateEncoder(
     encodedBytes
   }
 
-  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
-
   override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
     val prefixKeyEncodedLen = Platform.getInt(keyBytes, 
Platform.BYTE_ARRAY_OFFSET)
     val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
@@ -163,10 +168,6 @@ class PrefixKeyScanStateEncoder(
     
restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
   }
 
-  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
-    decodeToUnsafeRow(valueBytes, valueRow)
-  }
-
   override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
     prefixKeyProjection(key)
   }
@@ -180,14 +181,12 @@ class PrefixKeyScanStateEncoder(
     prefix
   }
 
-  override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
-    rowTuple.withRows(decodeKey(byteArrayTuple.key), 
decodeValue(byteArrayTuple.value))
-  }
-
   override def supportPrefixKeyScan: Boolean = true
 }
 
 /**
+ * RocksDB Key Encoder for UnsafeRow that does not support prefix key scan.
+ *
  * Encodes/decodes UnsafeRows to versioned byte arrays.
  * It uses the first byte of the generated byte array to store the version the 
describes how the
  * row is encoded in the rest of the byte array. Currently, the default 
version is 0,
@@ -197,20 +196,16 @@ class PrefixKeyScanStateEncoder(
  *    (offset 0 is the version byte of value 0). That is, if the unsafe row 
has N bytes,
  *    then the generated array byte will be N+1 bytes.
  */
-class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
-  extends RocksDBStateEncoder {
+class NoPrefixKeyStateEncoder(keySchema: StructType)
+  extends RocksDBKeyStateEncoder {
 
   import RocksDBStateEncoder._
 
   // Reusable objects
   private val keyRow = new UnsafeRow(keySchema.size)
-  private val valueRow = new UnsafeRow(valueSchema.size)
-  private val rowTuple = new UnsafeRowPair()
 
   override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
 
-  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
-
   /**
    * Decode byte array for a key to a UnsafeRow.
    * @note The UnsafeRow returned is reused across calls, and the UnsafeRow 
just points to
@@ -220,26 +215,6 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, 
valueSchema: StructType)
     decodeToUnsafeRow(keyBytes, keyRow)
   }
 
-  /**
-   * Decode byte array for a value to a UnsafeRow.
-   *
-   * @note The UnsafeRow returned is reused across calls, and the UnsafeRow 
just points to
-   *       the given byte array.
-   */
-  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
-    decodeToUnsafeRow(valueBytes, valueRow)
-  }
-
-  /**
-   * Decode pair of key-value byte arrays in a pair of key-value UnsafeRows.
-   *
-   * @note The UnsafeRow returned is reused across calls, and the UnsafeRow 
just points to
-   *       the given byte array.
-   */
-  override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
-    rowTuple.withRows(decodeKey(byteArrayTuple.key), 
decodeValue(byteArrayTuple.value))
-  }
-
   override def supportPrefixKeyScan: Boolean = false
 
   override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
@@ -250,3 +225,36 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, 
valueSchema: StructType)
     throw new IllegalStateException("This encoder doesn't support prefix key!")
   }
 }
+
+/**
+ * RocksDB Value Encoder for UnsafeRow that only supports single value.
+ *
+ * Encodes/decodes UnsafeRows to versioned byte arrays.
+ * It uses the first byte of the generated byte array to store the version the 
describes how the
+ * row is encoded in the rest of the byte array. Currently, the default 
version is 0,
+ *
+ * VERSION 0:  [ VERSION (1 byte) | ROW (N bytes) ]
+ *    The bytes of a UnsafeRow is written unmodified to starting from offset 1
+ *    (offset 0 is the version byte of value 0). That is, if the unsafe row 
has N bytes,
+ *    then the generated array byte will be N+1 bytes.
+ */
+class SingleValueStateEncoder(valueSchema: StructType)
+  extends RocksDBValueStateEncoder {
+
+  import RocksDBStateEncoder._
+
+  // Reusable objects
+  private val valueRow = new UnsafeRow(valueSchema.size)
+
+  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+  /**
+   * Decode byte array for a value to a UnsafeRow.
+   *
+   * @note The UnsafeRow returned is reused across calls, and the UnsafeRow 
just points to
+   *       the given byte array.
+   */
+  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+    decodeToUnsafeRow(valueBytes, valueRow)
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index e469fd4fe1c2..0c3487ba4dd7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -48,16 +48,26 @@ private[sql] class RocksDBStateStoreProvider
 
     override def version: Long = lastVersion
 
-    override def createColFamilyIfAbsent(colFamilyName: String): Unit = {
+    override def createColFamilyIfAbsent(
+        colFamilyName: String,
+        keySchema: StructType,
+        numColsPrefixKey: Int,
+        valueSchema: StructType): Unit = {
       verify(colFamilyName != StateStore.DEFAULT_COL_FAMILY_NAME,
         s"Failed to create column family with reserved_name=$colFamilyName")
+      verify(useColumnFamilies, "Column families are not supported in this 
store")
       rocksDB.createColFamilyIfAbsent(colFamilyName)
+      keyValueEncoderMap.putIfAbsent(colFamilyName,
+        (RocksDBStateEncoder.getKeyEncoder(keySchema, numColsPrefixKey),
+         RocksDBStateEncoder.getValueEncoder(valueSchema)))
     }
 
     override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
       verify(key != null, "Key cannot be null")
-      val value = encoder.decodeValue(rocksDB.get(encoder.encodeKey(key), 
colFamilyName))
-      if (!isValidated && value != null) {
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      val value = kvEncoder._2.decodeValue(
+        rocksDB.get(kvEncoder._1.encodeKey(key), colFamilyName))
+      if (!isValidated && value != null && !useColumnFamilies) {
         StateStoreProvider.validateStateRowFormat(
           key, keySchema, value, valueSchema, storeConf)
         isValidated = true
@@ -69,19 +79,25 @@ private[sql] class RocksDBStateStoreProvider
       verify(state == UPDATING, "Cannot put after already committed or 
aborted")
       verify(key != null, "Key cannot be null")
       require(value != null, "Cannot put a null value")
-      rocksDB.put(encoder.encodeKey(key), encoder.encodeValue(value), 
colFamilyName)
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      rocksDB.put(kvEncoder._1.encodeKey(key),
+        kvEncoder._2.encodeValue(value), colFamilyName)
     }
 
     override def remove(key: UnsafeRow, colFamilyName: String): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or 
aborted")
       verify(key != null, "Key cannot be null")
-      rocksDB.remove(encoder.encodeKey(key), colFamilyName)
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      rocksDB.remove(kvEncoder._1.encodeKey(key), colFamilyName)
     }
 
     override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      val rowPair = new UnsafeRowPair()
       rocksDB.iterator(colFamilyName).map { kv =>
-        val rowPair = encoder.decode(kv)
-        if (!isValidated && rowPair.value != null) {
+        rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
+          kvEncoder._2.decodeValue(kv.value))
+        if (!isValidated && rowPair.value != null && !useColumnFamilies) {
           StateStoreProvider.validateStateRowFormat(
             rowPair.key, keySchema, rowPair.value, valueSchema, storeConf)
           isValidated = true
@@ -92,10 +108,17 @@ private[sql] class RocksDBStateStoreProvider
 
     override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
       Iterator[UnsafeRowPair] = {
-      require(encoder.supportPrefixKeyScan, "Prefix scan requires setting 
prefix key!")
-
-      val prefix = encoder.encodePrefixKey(prefixKey)
-      rocksDB.prefixScan(prefix, colFamilyName).map(kv => encoder.decode(kv))
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      require(kvEncoder._1.supportPrefixKeyScan,
+        "Prefix scan requires setting prefix key!")
+
+      val prefix = kvEncoder._1.encodePrefixKey(prefixKey)
+      val rowPair = new UnsafeRowPair()
+      rocksDB.prefixScan(prefix, colFamilyName).map { kv =>
+        rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
+          kvEncoder._2.decodeValue(kv.value))
+        rowPair
+      }
     }
 
     override def commit(): Long = synchronized {
@@ -191,8 +214,10 @@ private[sql] class RocksDBStateStoreProvider
     def dbInstance(): RocksDB = rocksDB
 
     /** Remove column family if exists */
-     override def removeColFamilyIfExists(colFamilyName: String): Unit = {
-       rocksDB.removeColFamilyIfExists(colFamilyName)
+    override def removeColFamilyIfExists(colFamilyName: String): Unit = {
+      verify(useColumnFamilies, "Column families are not supported in this 
store")
+      rocksDB.removeColFamilyIfExists(colFamilyName)
+      keyValueEncoderMap.remove(colFamilyName)
     }
   }
 
@@ -215,7 +240,9 @@ private[sql] class RocksDBStateStoreProvider
       (keySchema.length > numColsPrefixKey), "The number of columns in the key 
must be " +
       "greater than the number of columns for prefix key!")
 
-    this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema, 
numColsPrefixKey)
+    keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
+      (RocksDBStateEncoder.getKeyEncoder(keySchema, numColsPrefixKey),
+       RocksDBStateEncoder.getValueEncoder(valueSchema)))
 
     rocksDB // lazy initialization
   }
@@ -287,7 +314,8 @@ private[sql] class RocksDBStateStoreProvider
       useColumnFamilies)
   }
 
-  @volatile private var encoder: RocksDBStateEncoder = _
+  private val keyValueEncoderMap = new 
java.util.concurrent.ConcurrentHashMap[String,
+    (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)]
 
   private def verify(condition: => Boolean, msg: String): Unit = {
     if (!condition) { throw new IllegalStateException(msg) }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 4b409b8a66b7..7207a4746196 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -112,7 +112,11 @@ trait StateStore extends ReadStateStore {
   /**
    * Create column family with given name, if absent.
    */
-  def createColFamilyIfAbsent(colFamilyName: String): Unit
+  def createColFamilyIfAbsent(
+      colFamilyName: String,
+      keySchema: StructType,
+      numColsPrefixKey: Int,
+      valueSchema: StructType): Unit
 
   /**
    * Put a new non-null value for a non-null key. Implementations must be 
aware that the UnsafeRows
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
index 02052d307f41..b7a738786e3f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
 
 class MemoryStateStore extends StateStore() {
   import scala.jdk.CollectionConverters._
@@ -29,7 +30,11 @@ class MemoryStateStore extends StateStore() {
     map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, 
e.getValue) }
   }
 
-  override def createColFamilyIfAbsent(colFamilyName: String): Unit = {
+  override def createColFamilyIfAbsent(
+      colFamilyName: String,
+      keySchema: StructType,
+      numColsPrefixKey: Int,
+      valueSchema: StructType): Unit = {
     throw 
StateStoreErrors.multipleColumnFamiliesNotSupported("MemoryStateStoreProvider")
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index b820247f6282..f2811a23fd8a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -158,6 +158,30 @@ class RocksDBStateStoreSuite extends 
StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  testWithColumnFamilies("rocksdb key and value schema encoders for column 
families",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled 
=>
+    val testColFamily = "testState"
+
+    tryWithProviderResource(newStoreProvider(colFamiliesEnabled)) { provider =>
+      val store = provider.getStore(0)
+      if (colFamiliesEnabled) {
+        store.createColFamilyIfAbsent(testColFamily, keySchema, 
numColsPrefixKey = 0, valueSchema)
+        val keyRow1 = dataToKeyRow("a", 0)
+        val valueRow1 = dataToValueRow(1)
+        store.put(keyRow1, valueRow1, colFamilyName = testColFamily)
+        assert(valueRowToData(store.get(keyRow1, colFamilyName = 
testColFamily)) === 1)
+        store.remove(keyRow1, colFamilyName = testColFamily)
+        assert(store.get(keyRow1, colFamilyName = testColFamily) === null)
+      }
+      val keyRow2 = dataToKeyRow("b", 0)
+      val valueRow2 = dataToValueRow(2)
+      store.put(keyRow2, valueRow2)
+      assert(valueRowToData(store.get(keyRow2)) === 2)
+      store.remove(keyRow2)
+      assert(store.get(keyRow2) === null)
+    }
+  }
+
   override def newStoreProvider(): RocksDBStateStoreProvider = {
     newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0))
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to