This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 88b29c5076d4 [SPARK-47570][SS] Integrate range scan encoder changes 
with timer implementation
88b29c5076d4 is described below

commit 88b29c5076d48f4ecbed402a693a8ccce57cd7d0
Author: jingz-db <jing.z...@databricks.com>
AuthorDate: Wed Mar 27 13:37:48 2024 +0900

    [SPARK-47570][SS] Integrate range scan encoder changes with timer 
implementation
    
    ### What changes were proposed in this pull request?
    
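    Previously, the timer state implementation used the no-prefix RocksDB state
encoder, so the iterator returned by `iterator()` or `prefix()` was not sorted
by timestamp. Now that Anish's PR adds support for a range scan encoder, we
integrate it with `TimerStateImpl` so that the timestamp-to-key secondary index
is encoded with the range scan encoder. Because that index is now iterated in
ascending timestamp order, `getExpiredTimers` takes an expiry threshold and
stops scanning at the first timer that has not yet expired.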
    
    ### Why are the changes needed?
    
    The changes are part of the work on adding a new stateful streaming operator
for arbitrary state management, which provides the new features listed in the
SPIP JIRA: https://issues.apache.org/jira/browse/SPARK-45939
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit tests in `TimerSuite`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45709 from jingz-db/integrate-range-scan.
    
    Authored-by: jingz-db <jing.z...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../streaming/StatefulProcessorHandleImpl.scala    |  8 ++-
 .../sql/execution/streaming/TimerStateImpl.scala   | 19 ++++--
 .../streaming/TransformWithStateExec.scala         | 16 ++---
 .../sql/execution/streaming/state/TimerSuite.scala | 69 +++++++++++++++++++---
 4 files changed, 85 insertions(+), 27 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 9b905ad5235d..5f3b794fd117 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -163,12 +163,14 @@ class StatefulProcessorHandleImpl(
   }
 
   /**
-   * Function to retrieve all registered timers for all grouping keys
+   * Function to retrieve all expired registered timers for all grouping keys
+   * @param expiryTimestampMs Threshold for expired timestamp in milliseconds, 
this function
+   *                          will return all timers that have timestamp less 
than passed threshold
    * @return - iterator of registered timers for all grouping keys
    */
-  def getExpiredTimers(): Iterator[(Any, Long)] = {
+  def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
     verifyTimerOperations("get_expired_timers")
-    timerState.getExpiredTimers()
+    timerState.getExpiredTimers(expiryTimestampMs)
   }
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
index 6166374d25e9..af321eecb4db 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
@@ -91,7 +91,7 @@ class TimerStateImpl(
 
   val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
   store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex,
-    schemaForValueRow, NoPrefixKeyStateEncoderSpec(keySchemaForSecIndex),
+    schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1),
     useMultipleValuesPerKey = false, isInternal = true)
 
   private def getGroupingKey(cfName: String): Any = {
@@ -110,7 +110,6 @@ class TimerStateImpl(
 
   // We maintain a secondary index that inverts the ordering of the timestamp
   // and grouping key
-  // TODO: use range scan encoder to encode the secondary index key
   private def encodeSecIndexKey(groupingKey: Any, expiryTimestampMs: Long): 
UnsafeRow = {
     val keyByteArr = 
keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes()
     val keyRow = secIndexKeyEncoder(InternalRow(expiryTimestampMs, keyByteArr))
@@ -187,10 +186,15 @@ class TimerStateImpl(
   }
 
   /**
-   * Function to get all the registered timers for all grouping keys
+   * Function to get all the expired registered timers for all grouping keys.
+   * Perform a range scan on timestamp and will stop iterating once the key 
row timestamp equals or
+   * exceeds the limit (as timestamp key is increasingly sorted).
+   * @param expiryTimestampMs Threshold for expired timestamp in milliseconds, 
this function
+   *                          will return all timers that have timestamp less 
than passed threshold.
    * @return - iterator of all the registered timers for all grouping keys
    */
-  def getExpiredTimers(): Iterator[(Any, Long)] = {
+  def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
+    // this iterator is sorted by timestamp in ascending order
     val iter = store.iterator(tsToKeyCFName)
 
     new NextIterator[(Any, Long)] {
@@ -199,7 +203,12 @@ class TimerStateImpl(
           val rowPair = iter.next()
           val keyRow = rowPair.key
           val result = getTimerRowFromSecIndex(keyRow)
-          result
+          if (result._2 < expiryTimestampMs) {
+            result
+          } else {
+            finished = true
+            null.asInstanceOf[(Any, Long)]
+          }
         } else {
           finished = true
           null.asInstanceOf[(Any, Long)]
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 39365e92185a..d3640ebd8850 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -160,26 +160,18 @@ case class TransformWithStateExec(
       case ProcessingTime =>
         assert(batchTimestampMs.isDefined)
         val batchTimestamp = batchTimestampMs.get
-        val procTimeIter = processorHandle.getExpiredTimers()
-        procTimeIter.flatMap { case (keyObj, expiryTimestampMs) =>
-          if (expiryTimestampMs < batchTimestamp) {
+        processorHandle.getExpiredTimers(batchTimestamp)
+          .flatMap { case (keyObj, expiryTimestampMs) =>
             handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
-          } else {
-            Iterator.empty
           }
-        }
 
       case EventTime =>
         assert(eventTimeWatermarkForEviction.isDefined)
         val watermark = eventTimeWatermarkForEviction.get
-        val eventTimeIter = processorHandle.getExpiredTimers()
-        eventTimeIter.flatMap { case (keyObj, expiryTimestampMs) =>
-          if (expiryTimestampMs < watermark) {
+        processorHandle.getExpiredTimers(watermark)
+          .flatMap { case (keyObj, expiryTimestampMs) =>
             handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
-          } else {
-            Iterator.empty
           }
-        }
 
       case _ => Iterator.empty
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
index 1aae0e0498aa..1af33aa7b5ad 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
@@ -48,7 +48,8 @@ class TimerSuite extends StateVariableSuiteBase {
         Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
       timerState.registerTimer(1L * 1000)
       assert(timerState.listTimers().toSet === Set(1000L))
-      assert(timerState.getExpiredTimers().toSet === Set(("test_key", 1000L)))
+      assert(timerState.getExpiredTimers(Long.MaxValue).toSeq === 
Seq(("test_key", 1000L)))
+      assert(timerState.getExpiredTimers(Long.MinValue).toSeq === 
Seq.empty[Long])
 
       timerState.registerTimer(20L * 1000)
       assert(timerState.listTimers().toSet === Set(20000L, 1000L))
@@ -69,8 +70,10 @@ class TimerSuite extends StateVariableSuiteBase {
       timerState1.registerTimer(1L * 1000)
       timerState2.registerTimer(15L * 1000)
       assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
-      assert(timerState1.getExpiredTimers().toSet ===
-        Set(("test_key", 15000L), ("test_key", 1000L)))
+      assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq ===
+        Seq(("test_key", 1000L), ("test_key", 15000L)))
+      // a timer whose timestamp equals expiryTimestampMs is not considered expired
+      assert(timerState1.getExpiredTimers(15000L).toSeq === Seq(("test_key", 
1000L)))
       assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
 
       timerState1.registerTimer(20L * 1000)
@@ -99,15 +102,67 @@ class TimerSuite extends StateVariableSuiteBase {
       ImplicitGroupingKeyTracker.removeImplicitKey()
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
-      assert(timerState1.getExpiredTimers().toSet ===
-        Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
+      assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq ===
+        Seq(("test_key1", 1000L), ("test_key1", 2000L), ("test_key2", 15000L)))
+      assert(timerState1.getExpiredTimers(10000L).toSeq ===
+        Seq(("test_key1", 1000L), ("test_key1", 2000L)))
       assert(timerState1.listTimers().toSet === Set(1000L, 2000L))
       ImplicitGroupingKeyTracker.removeImplicitKey()
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key2")
       assert(timerState2.listTimers().toSet === Set(15000L))
-      assert(timerState2.getExpiredTimers().toSet ===
-        Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
+      assert(timerState2.getExpiredTimers(1500L).toSeq === Seq(("test_key1", 
1000L)))
+    }
+  }
+
+  testWithTimeOutMode("Range scan on second index timer key - " +
+    "verify timestamp is sorted for single instance") { timeoutMode =>
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
+      val store = provider.getStore(0)
+
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+      val timerState = new TimerStateImpl(store, timeoutMode,
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+      val timerTimerstamps = Seq(931L, 8000L, 452300L, 4200L, 90L, 1L, 2L, 8L, 
3L, 35L, 6L, 9L, 5L)
+      // register/put unordered timestamp into rocksDB
+      timerTimerstamps.foreach(timerState.registerTimer)
+      assert(timerState.getExpiredTimers(Long.MaxValue).toSeq.map(_._2) === 
timerTimerstamps.sorted)
+      assert(timerState.getExpiredTimers(4200L).toSeq.map(_._2) ===
+        timerTimerstamps.sorted.takeWhile(_ < 4200L))
+      assert(timerState.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
+      ImplicitGroupingKeyTracker.removeImplicitKey()
+    }
+  }
+
+  testWithTimeOutMode("test range scan on second index timer key - " +
+    "verify timestamp is sorted for multiple instances") { timeoutMode =>
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
+      val store = provider.getStore(0)
+
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
+      val timerState1 = new TimerStateImpl(store, timeoutMode,
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+      val timerTimestamps1 = Seq(64L, 32L, 1024L, 4096L, 0L, 1L)
+      timerTimestamps1.foreach(timerState1.registerTimer)
+
+      val timerState2 = new TimerStateImpl(store, timeoutMode,
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+      val timerTimestamps2 = Seq(931L, 8000L, 452300L, 4200L)
+      timerTimestamps2.foreach(timerState2.registerTimer)
+      ImplicitGroupingKeyTracker.removeImplicitKey()
+
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key3")
+      val timerState3 = new TimerStateImpl(store, timeoutMode,
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+      val timerTimerStamps3 = Seq(1L, 2L, 8L, 3L)
+      timerTimerStamps3.foreach(timerState3.registerTimer)
+      ImplicitGroupingKeyTracker.removeImplicitKey()
+
+      assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq.map(_._2) ===
+        (timerTimestamps1 ++ timerTimestamps2 ++ timerTimerStamps3).sorted)
+      assert(timerState1.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
+      assert(timerState1.getExpiredTimers(8000L).toSeq.map(_._2) ===
+        (timerTimestamps1 ++ timerTimestamps2 ++ 
timerTimerStamps3).sorted.takeWhile(_ < 8000L))
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to