This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 60806c63d97b [SPARK-47746] Implement ordinal-based range encoding in the RocksDBStateEncoder
60806c63d97b is described below

commit 60806c63d97bc35f62a049b2185eb921217904c4
Author: Neil Ramaswamy <neil.ramasw...@databricks.com>
AuthorDate: Mon Apr 8 17:39:18 2024 +0900

    [SPARK-47746] Implement ordinal-based range encoding in the RocksDBStateEncoder
    
    ### What changes were proposed in this pull request?
    
    The RocksDBStateEncoder now implements range projection by reading a list of ordering ordinals and using it to project those columns, encoded in big-endian, to the front of the `Array[Byte]` encoded rows returned by the encoder.
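
    For intuition (an illustrative sketch, not code from this patch): RocksDB compares keys with an unsigned byte-wise comparator, so fixed-width values written in big-endian compare numerically once a marker byte orders negatives before non-negatives. The encoder prepends a single byte for similar null/sign tracking; the helper names below are hypothetical.

        import java.nio.{ByteBuffer, ByteOrder}

        // 1 marker byte (0 = negative, 1 = non-negative) + 8 bytes big-endian.
        def encodeLongForRangeScan(value: Long): Array[Byte] = {
          val bbuf = ByteBuffer.allocate(9).order(ByteOrder.BIG_ENDIAN)
          bbuf.put(if (value < 0) 0.toByte else 1.toByte)
          bbuf.putLong(value)
          bbuf.array()
        }

        // Unsigned lexicographic comparison, as RocksDB applies to keys.
        def byteCompare(a: Array[Byte], b: Array[Byte]): Int =
          a.zip(b).map { case (x, y) => (x & 0xFF) - (y & 0xFF) }
            .find(_ != 0).getOrElse(a.length - b.length)

        assert(byteCompare(encodeLongForRangeScan(-8350L), encodeLongForRangeScan(-55L)) < 0)
        assert(byteCompare(encodeLongForRangeScan(-1L), encodeLongForRangeScan(0L)) < 0)
        assert(byteCompare(encodeLongForRangeScan(90L), encodeLongForRangeScan(931L)) < 0)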
    
    ### Why are the changes needed?
    
    StateV2 implementations (and other state-related operators) project certain columns to the front of `UnsafeRow`s and then rely on the RocksDBStateEncoder to range-encode those columns. We can avoid the initial projection by passing the RocksDBStateEncoder the ordinals to encode at the front. This should avoid any GC or codegen overhead associated with projection.
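
    For example, with the new `RangeKeyScanStateEncoderSpec` signature, a caller whose range-scan column is not at ordinal 0 can name the ordinal directly instead of projecting it to the front (the schema and field names below are hypothetical):

        import org.apache.spark.sql.execution.streaming.state.RangeKeyScanStateEncoderSpec
        import org.apache.spark.sql.types._

        // Hypothetical key schema: the range-scan column sits at ordinal 1.
        val keySchema = StructType(Seq(
          StructField("groupingKey", StringType, nullable = false),
          StructField("expiresAtMs", LongType, nullable = false)))

        // Before this patch the spec took a count of leading columns
        // (numOrderingCols), forcing callers to project the ordering column
        // to the front first; now the ordinals are passed as-is.
        val spec = RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals = Seq(1))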
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New UTs. All existing UTs should pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes
    
    Closes #45905 from neilramaswamy/spark-47746.
    
    Authored-by: Neil Ramaswamy <neil.ramasw...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../src/main/resources/error/error-classes.json    |   2 +-
 docs/sql-error-conditions.md                       |   2 +-
 .../spark/sql/execution/streaming/TTLState.scala   |   2 +-
 .../sql/execution/streaming/TimerStateImpl.scala   |   2 +-
 .../streaming/state/RocksDBStateEncoder.scala      |  87 ++++++----
 .../sql/execution/streaming/state/StateStore.scala |   7 +-
 .../streaming/state/RocksDBStateStoreSuite.scala   | 184 ++++++++++++++++++---
 .../streaming/state/StateStoreSuite.scala          |   2 +-
 8 files changed, 228 insertions(+), 60 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index f28adaf40230..c3a01e9dcd90 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3630,7 +3630,7 @@
   },
   "STATE_STORE_INCORRECT_NUM_ORDERING_COLS_FOR_RANGE_SCAN" : {
     "message" : [
-      "Incorrect number of ordering columns=<numOrderingCols> for range scan 
encoder. Ordering columns cannot be zero or greater than num of schema columns."
+      "Incorrect number of ordering ordinals=<numOrderingCols> for range scan 
encoder. The number of ordering ordinals cannot be zero or greater than number 
of schema columns."
     ],
     "sqlState" : "42802"
   },
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index d8261b8c2765..1887af2e814b 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -2236,7 +2236,7 @@ Please only use the StatefulProcessor within the transformWithState operator.
 
 [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
 
-Incorrect number of ordering columns=`<numOrderingCols>` for range scan encoder. Ordering columns cannot be zero or greater than num of schema columns.
+Incorrect number of ordering ordinals=`<numOrderingCols>` for range scan encoder. The number of ordering ordinals cannot be zero or greater than number of schema columns.
 
 ### STATE_STORE_INCORRECT_NUM_PREFIX_COLS_FOR_PREFIX_SCAN
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
index 0ae93549b731..f64c8cc44555 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
@@ -93,7 +93,7 @@ abstract class SingleKeyTTLStateImpl(
     UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
 
   store.createColFamilyIfAbsent(ttlColumnFamilyName, TTL_KEY_ROW_SCHEMA, TTL_VALUE_ROW_SCHEMA,
-    RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, 1), isInternal = true)
+    RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, Seq(0)), isInternal = true)
 
   def upsertTTLForStateKey(
       expirationMs: Long,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
index 8d410b677c84..55acc4953c50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
@@ -91,7 +91,7 @@ class TimerStateImpl(
 
   private val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
   store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex,
-    schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1),
+    schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)),
     useMultipleValuesPerKey = false, isInternal = true)
 
   private def getGroupingKey(cfName: String): Any = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index e9b910a76148..80c228d15334 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -52,8 +52,8 @@ object RocksDBStateEncoder {
       case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
         new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey)
 
-      case RangeKeyScanStateEncoderSpec(keySchema, numOrderingCols) =>
-        new RangeKeyScanStateEncoder(keySchema, numOrderingCols)
+      case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+        new RangeKeyScanStateEncoder(keySchema, orderingOrdinals)
 
       case _ =>
        throw new IllegalArgumentException(s"Unsupported key state encoder spec: " +
@@ -204,8 +204,8 @@ class PrefixKeyScanStateEncoder(
 /**
  * RocksDB Key Encoder for UnsafeRow that supports range scan for fixed size fields
  *
- * To encode a row for range scan, we first project the first numOrderingCols needed
- * for the range scan into an UnsafeRow; we then rewrite that UnsafeRow's fields in BIG_ENDIAN
+ * To encode a row for range scan, we first project the orderingOrdinals from the original
+ * UnsafeRow into another UnsafeRow; we then rewrite that new UnsafeRow's fields in BIG_ENDIAN
  * to allow for scanning keys in sorted order using the byte-wise comparison method that
  * RocksDB uses.
  *
@@ -213,9 +213,9 @@ class PrefixKeyScanStateEncoder(
 * We then effectively join these two UnsafeRows together, and finally take those bytes
  * to get the resulting row.
  *
- * We cannot support variable sized fields given the UnsafeRow format which stores variable
- * sized fields as offset and length pointers to the actual values, thereby changing the required
- * ordering.
+ * We cannot support variable sized fields in the range scan because the UnsafeRow format
+ * stores variable sized fields as offset and length pointers to the actual values,
+ * thereby changing the required ordering.
  *
 * Note that we also support "null" values being passed for these fixed size fields. We prepend
 * a single byte to indicate whether the column value is null or not. We cannot change the
@@ -229,16 +229,19 @@ class PrefixKeyScanStateEncoder(
  * here: https://en.wikipedia.org/wiki/IEEE_754#Design_rationale
  *
  * @param keySchema - schema of the key to be encoded
- * @param numOrderingCols - number of columns to be used for range scan
+ * @param orderingOrdinals - the ordinals for which the range scan is constructed
  */
 class RangeKeyScanStateEncoder(
     keySchema: StructType,
-    numOrderingCols: Int) extends RocksDBKeyStateEncoder {
+    orderingOrdinals: Seq[Int]) extends RocksDBKeyStateEncoder {
 
   import RocksDBStateEncoder._
 
-  private val rangeScanKeyFieldsWithIdx: Seq[(StructField, Int)] = {
-    keySchema.zipWithIndex.take(numOrderingCols)
+  private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
+    orderingOrdinals.map { ordinal =>
+      val field = keySchema(ordinal)
+      (field, ordinal)
+    }
   }
 
   private def isFixedSize(dataType: DataType): Boolean = dataType match {
@@ -248,34 +251,56 @@ class RangeKeyScanStateEncoder(
   }
 
   // verify that only fixed sized columns are used for ordering
-  rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+  rangeScanKeyFieldsWithOrdinal.foreach { case (field, ordinal) =>
     if (!isFixedSize(field.dataType)) {
       // NullType is technically fixed size, but not supported for ordering
       if (field.dataType == NullType) {
-        throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, idx.toString)
+        throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, ordinal.toString)
       } else {
-        throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, idx.toString)
+        throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, ordinal.toString)
       }
     }
   }
 
-  private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
-    keySchema.zipWithIndex.drop(numOrderingCols)
+  private val remainingKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
+    0.to(keySchema.length - 1).diff(orderingOrdinals).map { ordinal =>
+      val field = keySchema(ordinal)
+      (field, ordinal)
+    }
   }
 
   private val rangeScanKeyProjection: UnsafeProjection = {
-    val refs = rangeScanKeyFieldsWithIdx.map(x =>
+    val refs = rangeScanKeyFieldsWithOrdinal.map(x =>
       BoundReference(x._2, x._1.dataType, x._1.nullable))
     UnsafeProjection.create(refs)
   }
 
   private val remainingKeyProjection: UnsafeProjection = {
-    val refs = remainingKeyFieldsWithIdx.map(x =>
+    val refs = remainingKeyFieldsWithOrdinal.map(x =>
       BoundReference(x._2, x._1.dataType, x._1.nullable))
     UnsafeProjection.create(refs)
   }
 
-  private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
+  // The original schema that we might get could be:
+  //    [foo, bar, baz, buzz]
+  // We might order by bar and buzz, leading to:
+  //    [bar, buzz, foo, baz]
+  // We need to create a projection that sends, for example, the buzz at index 1 to index
+  // 3. Thus, for every field in the original schema, we compute where it would be in
+  // the joined row and create a projection based on that.
+  private val restoreKeyProjection: UnsafeProjection = {
+    val refs = keySchema.zipWithIndex.map { case (field, originalOrdinal) =>
+      val ordinalInJoinedRow = if (orderingOrdinals.contains(originalOrdinal)) {
+          orderingOrdinals.indexOf(originalOrdinal)
+      } else {
+          orderingOrdinals.length +
+            remainingKeyFieldsWithOrdinal.indexWhere(_._2 == originalOrdinal)
+      }
+
+      BoundReference(ordinalInJoinedRow, field.dataType, field.nullable)
+    }
+    UnsafeProjection.create(refs)
+  }
 
   // Reusable objects
   private val joinedRowOnKey = new JoinedRow()
@@ -307,9 +332,10 @@ class RangeKeyScanStateEncoder(
   // the sorting order on iteration.
  // Also note that the same byte is used to indicate whether the value is negative or not.
   private def encodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
-    val writer = new UnsafeRowWriter(numOrderingCols)
+    val writer = new UnsafeRowWriter(orderingOrdinals.length)
     writer.resetRowWriter()
-    rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
       val value = row.get(idx, field.dataType)
       // Note that we cannot allocate a smaller buffer here even if the value is null
       // because the effective byte array is considered variable size and needs to have
@@ -413,9 +439,11 @@ class RangeKeyScanStateEncoder(
   // actual value.
  // For negative float/double values, we need to flip all the bits back to get the original value.
   private def decodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
-    val writer = new UnsafeRowWriter(numOrderingCols)
+    val writer = new UnsafeRowWriter(orderingOrdinals.length)
     writer.resetRowWriter()
-    rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
+
       val value = row.getBinary(idx)
       val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
       bbuf.order(ByteOrder.BIG_ENDIAN)
@@ -464,10 +492,11 @@ class RangeKeyScanStateEncoder(
   }
 
   override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    // This prefix key has the columns specified by orderingOrdinals
     val prefixKey = extractPrefixKey(row)
    val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))
 
-    val result = if (numOrderingCols < keySchema.length) {
+    val result = if (orderingOrdinals.length < keySchema.length) {
       val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
      val encodedBytes = new Array[Byte](rangeScanKeyEncoded.length + remainingEncoded.length + 4)
      Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, rangeScanKeyEncoded.length)
@@ -498,10 +527,10 @@ class RangeKeyScanStateEncoder(
       Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
 
     val prefixKeyDecodedForRangeScan = decodeToUnsafeRow(prefixKeyEncoded,
-      numFields = numOrderingCols)
+      numFields = orderingOrdinals.length)
    val prefixKeyDecoded = decodePrefixKeyForRangeScan(prefixKeyDecodedForRangeScan)
 
-    if (numOrderingCols < keySchema.length) {
+    if (orderingOrdinals.length < keySchema.length) {
      // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
       val remainingKeyEncodedLen = keyBytes.length - 4 - prefixKeyEncodedLen
 
@@ -511,9 +540,11 @@ class RangeKeyScanStateEncoder(
         remainingKeyEncodedLen)
 
       val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded,
-        numFields = keySchema.length - numOrderingCols)
+        numFields = keySchema.length - orderingOrdinals.length)
 
-      restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
+      val joined = joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)
+      val restored = restoreKeyProjection(joined)
+      restored
     } else {
      // if the number of ordering cols is same as the number of key schema cols, we only
       // return the prefix key decoded unsafe row.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index d3b3264b8e3d..959cbbaef8b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -301,11 +301,12 @@ case class PrefixKeyScanStateEncoderSpec(
   }
 }
 
+/** Encodes rows so that they can be range-scanned based on orderingOrdinals */
 case class RangeKeyScanStateEncoderSpec(
     keySchema: StructType,
-    numOrderingCols: Int) extends KeyStateEncoderSpec {
-  if (numOrderingCols == 0 || numOrderingCols > keySchema.length) {
-    throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(numOrderingCols.toString)
+    orderingOrdinals: Seq[Int]) extends KeyStateEncoderSpec {
+  if (orderingOrdinals.isEmpty || orderingOrdinals.length > keySchema.length) {
+    throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index 16a5935e04f4..f3eb8a392d04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.util.UUID
 
+import scala.collection.immutable
 import scala.util.Random
 
 import org.apache.hadoop.conf.Configuration
@@ -166,7 +167,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     // zero ordering cols
     val ex1 = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-        RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 0),
+        RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq()),
         colFamiliesEnabled)) { provider =>
         provider.getStore(0)
       }
@@ -180,10 +181,12 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       matchPVals = true
     )
 
-    // ordering cols greater than schema cols
+    // ordering ordinals greater than schema cols
     val ex2 = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-        RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, keySchemaWithRangeScan.length + 1),
+        RangeKeyScanStateEncoderSpec(
+          keySchemaWithRangeScan,
+          0.to(keySchemaWithRangeScan.length)),
         colFamiliesEnabled)) { provider =>
         provider.getStore(0)
       }
@@ -205,7 +208,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
     val ex = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithVariableSizeCols,
-        RangeKeyScanStateEncoderSpec(keySchemaWithVariableSizeCols, 1),
+        RangeKeyScanStateEncoderSpec(keySchemaWithVariableSizeCols, Seq(0)),
         colFamiliesEnabled)) { provider =>
         provider.getStore(0)
       }
@@ -221,6 +224,46 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     )
   }
 
+  testWithColumnFamilies("rocksdb range scan validation - variable size data 
types unsupported",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled 
=>
+    val keySchemaWithSomeUnsupportedTypeCols: StructType = StructType(Seq(
+      StructField("key1", StringType, false),
+      StructField("key2", IntegerType, false),
+      StructField("key3", FloatType, false),
+      StructField("key4", BinaryType, false)
+    ))
+    val allowedRangeOrdinals = Seq(1, 2)
+
+    keySchemaWithSomeUnsupportedTypeCols.fields.zipWithIndex.foreach { case (field, index) =>
+      val isAllowed = allowedRangeOrdinals.contains(index)
+
+      val getStore = () => {
+        tryWithProviderResource(newStoreProvider(keySchemaWithSomeUnsupportedTypeCols,
+            RangeKeyScanStateEncoderSpec(keySchemaWithSomeUnsupportedTypeCols, Seq(index)),
+            colFamiliesEnabled)) { provider =>
+            provider.getStore(0)
+        }
+      }
+
+      if (isAllowed) {
+        getStore()
+      } else {
+        val ex = intercept[SparkUnsupportedOperationException] {
+          getStore()
+        }
+        checkError(
+          ex,
+          errorClass = "STATE_STORE_VARIABLE_SIZE_ORDERING_COLS_NOT_SUPPORTED",
+          parameters = Map(
+            "fieldName" -> field.name,
+            "index" -> index.toString
+          ),
+          matchPVals = true
+        )
+      }
+    }
+  }
+
   testWithColumnFamilies("rocksdb range scan validation - null type columns",
    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
     val keySchemaWithNullTypeCols: StructType = StructType(
@@ -228,7 +271,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
     val ex = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithNullTypeCols,
-        RangeKeyScanStateEncoderSpec(keySchemaWithNullTypeCols, 1),
+        RangeKeyScanStateEncoderSpec(keySchemaWithNullTypeCols, Seq(0)),
         colFamiliesEnabled)) { provider =>
         provider.getStore(0)
       }
@@ -248,7 +291,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
 
     tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)),
+      colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       // use non-default col family if column families are enabled
@@ -256,7 +300,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           keySchemaWithRangeScan, valueSchema,
-          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1))
+          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)))
       }
 
      val timerTimestamps = Seq(931L, 8000L, 452300L, 4200L, -1L, 90L, 1L, 2L, 8L,
@@ -305,14 +349,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
     val schemaProj = UnsafeProjection.create(Array[DataType](DoubleType, StringType))
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 1))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0)))
       }
 
       // Verify that the sort ordering here is as follows:
@@ -355,14 +399,15 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
 
     tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)),
+      colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           keySchemaWithRangeScan, valueSchema,
-          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1))
+          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)))
       }
 
      val timerTimestamps = Seq(931L, 8000L, 452300L, 4200L, 90L, 1L, 2L, 8L, 3L, 35L, 6L, 9L, 5L,
@@ -415,14 +460,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val schemaProj = UnsafeProjection.create(Array[DataType](LongType, IntegerType, StringType))
 
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 2), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
      val timerTimestamps = Seq((931L, 10), (8000L, 40), (452300L, 1), (4200L, 68), (90L, 2000),
@@ -447,6 +492,96 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  testWithColumnFamilies("rocksdb range scan multiple non-contiguous ordering 
columns",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled ) { 
colFamiliesEnabled =>
+    val testSchema: StructType = StructType(
+      Seq(
+        StructField("ordering-1", LongType, false),
+        StructField("key2", StringType, false),
+        StructField("ordering-2", IntegerType, false),
+        StructField("string-2", StringType, false),
+        StructField("ordering-3", DoubleType, false)
+      )
+    )
+
+    val testSchemaProj = UnsafeProjection.create(Array[DataType](
+        immutable.ArraySeq.unsafeWrapArray(testSchema.fields.map(_.dataType)): _*))
+    val rangeScanOrdinals = Seq(0, 2, 4)
+
+    tryWithProviderResource(
+      newStoreProvider(
+        testSchema,
+        RangeKeyScanStateEncoderSpec(testSchema, rangeScanOrdinals),
+        colFamiliesEnabled
+      )
+    ) { provider =>
+      val store = provider.getStore(0)
+
+      val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
+      if (colFamiliesEnabled) {
+        store.createColFamilyIfAbsent(
+          cfName,
+          testSchema,
+          valueSchema,
+          RangeKeyScanStateEncoderSpec(testSchema, rangeScanOrdinals)
+        )
+      }
+
+      val orderedInput = Seq(
+        // Make sure that the first column takes precedence, even if the
+        // later columns are greater
+        (-2L, 0, 99.0),
+        (-1L, 0, 98.0),
+        (0L, 0, 97.0),
+        (2L, 0, 96.0),
+        // Make sure that the second column takes precedence, when the first
+        // column is all the same
+        (3L, -2, -1.0),
+        (3L, -1, -2.0),
+        (3L, 0, -3.0),
+        (3L, 2, -4.0),
+        // Finally, make sure that the third column takes precedence, when the
+        // first two ordering columns are the same.
+        (4L, -1, -127.0),
+        (4L, -1, 0.0),
+        (4L, -1, 64.0),
+        (4L, -1, 127.0)
+      )
+      val scrambledInput = Random.shuffle(orderedInput)
+
+      scrambledInput.foreach { record =>
+        val keyRow = testSchemaProj.apply(
+          new GenericInternalRow(
+            Array[Any](
+              record._1,
+              UTF8String.fromString(Random.alphanumeric.take(Random.nextInt(20) + 1).mkString),
+              record._2,
+              UTF8String.fromString(Random.alphanumeric.take(Random.nextInt(20) + 1).mkString),
+              record._3
+            )
+          )
+        )
+
+        // The value is just a "dummy" value of 1
+        val valueRow = dataToValueRow(1)
+        store.put(keyRow, valueRow, cfName)
+        assert(valueRowToData(store.get(keyRow, cfName)) === 1)
+      }
+
+      val result = store
+        .iterator(cfName)
+        .map { kv =>
+          val keyRow = kv.key
+          val key = (keyRow.getLong(0), keyRow.getInt(2), keyRow.getDouble(4))
+          (key._1, key._2, key._3)
+        }
+        .toSeq
+
+      assert(result === orderedInput)
+    }
+  }
+
+
   testWithColumnFamilies("rocksdb range scan multiple ordering columns - 
variable size " +
     s"non-ordering columns with null values in first ordering column",
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled 
=>
@@ -459,14 +594,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val schemaProj = UnsafeProjection.create(Array[DataType](LongType, IntegerType, StringType))
 
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 2), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
       val timerTimestamps = Seq((931L, 10), (null, 40), (452300L, 1),
@@ -522,7 +657,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       if (colFamiliesEnabled) {
         store1.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
       val timerTimestamps1 = Seq((null, 3), (null, 1), (null, 32),
@@ -559,14 +694,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val schemaProj = UnsafeProjection.create(Array[DataType](LongType, IntegerType, StringType))
 
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 2), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
       val timerTimestamps = Seq((931L, 10), (40L, null), (452300L, 1),
@@ -612,14 +747,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val schemaProj = UnsafeProjection.create(Array[DataType](ByteType, IntegerType, StringType))
 
     tryWithProviderResource(newStoreProvider(testSchema,
-      RangeKeyScanStateEncoderSpec(testSchema, 2), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           testSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(testSchema, 2))
+          RangeKeyScanStateEncoderSpec(testSchema, Seq(0, 1)))
       }
 
      val timerTimestamps: Seq[(Byte, Int)] = Seq((0x33, 10), (0x1A, 40), (0x1F, 1), (0x01, 68),
@@ -649,13 +784,13 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
     // use the same schema as value schema for single col key schema
     tryWithProviderResource(newStoreProvider(valueSchema,
-      RangeKeyScanStateEncoderSpec(valueSchema, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)), colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           valueSchema, valueSchema,
-          RangeKeyScanStateEncoderSpec(valueSchema, 1))
+          RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)))
       }
 
       val timerTimestamps = Seq(931, 8000, 452300, 4200,
@@ -690,14 +825,15 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
 
     tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1), colFamiliesEnabled)) { provider =>
+      RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)),
+      colFamiliesEnabled)) { provider =>
       val store = provider.getStore(0)
 
       val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
       if (colFamiliesEnabled) {
         store.createColFamilyIfAbsent(cfName,
           keySchemaWithRangeScan, valueSchema,
-          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1))
+          RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)))
       }
 
       val timerTimestamps = Seq(931L, -1331L, 8000L, 1L, -244L, -8350L, -55L)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 231396aff222..4523a14ca1cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -200,7 +200,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   test("running with range scan encoder should fail") {
     val ex = intercept[SparkUnsupportedOperationException] {
       tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan,
-        keyStateEncoderSpec = RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, 1),
+        keyStateEncoderSpec = RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)),
         useColumnFamilies = false)) { provider =>
         provider.getStore(0)
       }
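
The restore projection introduced above (restoreKeyProjection) can be sanity-checked with a small standalone sketch, following the [foo, bar, baz, buzz] example from the code comment; the variable names here are illustrative, not from the patch:

    // Schema [foo, bar, baz, buzz] ordered by bar and buzz gives a joined row
    // laid out as [bar, buzz, foo, baz]. For each original ordinal, compute its
    // position in the joined row, as restoreKeyProjection's BoundReferences do.
    val orderingOrdinals = Seq(1, 3)
    val schemaLength = 4
    val remainingOrdinals = (0 until schemaLength).diff(orderingOrdinals) // Seq(0, 2)

    val positionInJoinedRow = (0 until schemaLength).map { ordinal =>
      if (orderingOrdinals.contains(ordinal)) orderingOrdinals.indexOf(ordinal)
      else orderingOrdinals.length + remainingOrdinals.indexOf(ordinal)
    }
    // foo -> 2, bar -> 0, baz -> 3, buzz -> 1
    assert(positionInJoinedRow == Seq(2, 0, 3, 1))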

