This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 88b29c5076d4 [SPARK-47570][SS] Integrate range scan encoder changes with timer implementation
88b29c5076d4 is described below

commit 88b29c5076d48f4ecbed402a693a8ccce57cd7d0
Author: jingz-db <jing.z...@databricks.com>
AuthorDate: Wed Mar 27 13:37:48 2024 +0900

    [SPARK-47570][SS] Integrate range scan encoder changes with timer implementation

    ### What changes were proposed in this pull request?

    Previously the timer state implementation used the no-prefix RocksDB state encoder, so the
    iterator returned by `iterator()` or `prefix()` was not sorted on the timestamp value. After
    Anish's PR added support for the range scan encoder, we can integrate it with `TimerStateImpl`
    and use the range scan encoder on the `timer to key` secondary index.

    ### Why are the changes needed?

    The changes are part of the work to add a new stateful streaming operator for arbitrary state
    management, which provides the new features listed in the SPIP JIRA here -
    https://issues.apache.org/jira/browse/SPARK-45939

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Added unit tests in `TimerSuite`.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #45709 from jingz-db/integrate-range-scan.

    Authored-by: jingz-db <jing.z...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../streaming/StatefulProcessorHandleImpl.scala    |  8 ++-
 .../sql/execution/streaming/TimerStateImpl.scala   | 19 ++++--
 .../streaming/TransformWithStateExec.scala         | 16 ++---
 .../sql/execution/streaming/state/TimerSuite.scala | 69 +++++++++++++++++++---
 4 files changed, 85 insertions(+), 27 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 9b905ad5235d..5f3b794fd117 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -163,12 +163,14 @@ class StatefulProcessorHandleImpl(
   }
 
   /**
-   * Function to retrieve all registered timers for all grouping keys
+   * Function to retrieve all expired registered timers for all grouping keys
+   * @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function
+   *                          will return all timers that have timestamp less than passed threshold
    * @return - iterator of registered timers for all grouping keys
    */
-  def getExpiredTimers(): Iterator[(Any, Long)] = {
+  def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
     verifyTimerOperations("get_expired_timers")
-    timerState.getExpiredTimers()
+    timerState.getExpiredTimers(expiryTimestampMs)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
index 6166374d25e9..af321eecb4db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
@@ -91,7 +91,7 @@ class TimerStateImpl(
 
   val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
   store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex,
-    schemaForValueRow, NoPrefixKeyStateEncoderSpec(keySchemaForSecIndex),
+    schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1),
     useMultipleValuesPerKey = false, isInternal = true)
 
   private def getGroupingKey(cfName: String): Any = {
@@ -110,7 +110,6 @@ class TimerStateImpl(
 
   // We maintain a secondary index that inverts the ordering of the timestamp
   // and grouping key
-  // TODO: use range scan encoder to encode the secondary index key
   private def encodeSecIndexKey(groupingKey: Any, expiryTimestampMs: Long): UnsafeRow = {
     val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes()
     val keyRow = secIndexKeyEncoder(InternalRow(expiryTimestampMs, keyByteArr))
@@ -187,10 +186,15 @@ class TimerStateImpl(
   }
 
   /**
-   * Function to get all the registered timers for all grouping keys
+   * Function to get all the expired registered timers for all grouping keys.
+   * Perform a range scan on timestamp and will stop iterating once the key row timestamp equals or
+   * exceeds the limit (as timestamp key is increasingly sorted).
+   * @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function
+   *                          will return all timers that have timestamp less than passed threshold.
    * @return - iterator of all the registered timers for all grouping keys
    */
-  def getExpiredTimers(): Iterator[(Any, Long)] = {
+  def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
+    // this iter is increasingly sorted on timestamp
     val iter = store.iterator(tsToKeyCFName)
 
     new NextIterator[(Any, Long)] {
@@ -199,7 +203,12 @@ class TimerStateImpl(
         val rowPair = iter.next()
         val keyRow = rowPair.key
         val result = getTimerRowFromSecIndex(keyRow)
-        result
+        if (result._2 < expiryTimestampMs) {
+          result
+        } else {
+          finished = true
+          null.asInstanceOf[(Any, Long)]
+        }
       } else {
         finished = true
         null.asInstanceOf[(Any, Long)]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 39365e92185a..d3640ebd8850 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -160,26 +160,18 @@ case class TransformWithStateExec(
 
       case ProcessingTime =>
        assert(batchTimestampMs.isDefined)
        val batchTimestamp = batchTimestampMs.get
-        val procTimeIter = processorHandle.getExpiredTimers()
-        procTimeIter.flatMap { case (keyObj, expiryTimestampMs) =>
-          if (expiryTimestampMs < batchTimestamp) {
+        processorHandle.getExpiredTimers(batchTimestamp)
+          .flatMap { case (keyObj, expiryTimestampMs) =>
            handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
-          } else {
-            Iterator.empty
          }
-        }
 
      case EventTime =>
        assert(eventTimeWatermarkForEviction.isDefined)
        val watermark = eventTimeWatermarkForEviction.get
-        val eventTimeIter = processorHandle.getExpiredTimers()
-        eventTimeIter.flatMap { case (keyObj, expiryTimestampMs) =>
-          if (expiryTimestampMs < watermark) {
+        processorHandle.getExpiredTimers(watermark)
+          .flatMap { case (keyObj, expiryTimestampMs) =>
            handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
-          } else {
-            Iterator.empty
          }
-        }
 
      case _ => Iterator.empty
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
index 1aae0e0498aa..1af33aa7b5ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
@@ -48,7 +48,8 @@ class TimerSuite extends StateVariableSuiteBase {
        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
      timerState.registerTimer(1L * 1000)
      assert(timerState.listTimers().toSet === Set(1000L))
-      assert(timerState.getExpiredTimers().toSet === Set(("test_key", 1000L)))
+      assert(timerState.getExpiredTimers(Long.MaxValue).toSeq === Seq(("test_key", 1000L)))
+      assert(timerState.getExpiredTimers(Long.MinValue).toSeq === Seq.empty[Long])
 
      timerState.registerTimer(20L * 1000)
      assert(timerState.listTimers().toSet === Set(20000L, 1000L))
@@ -69,8 +70,10 @@ class TimerSuite extends StateVariableSuiteBase {
      timerState1.registerTimer(1L * 1000)
      timerState2.registerTimer(15L * 1000)
      assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
-      assert(timerState1.getExpiredTimers().toSet ===
-        Set(("test_key", 15000L), ("test_key", 1000L)))
+      assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq ===
+        Seq(("test_key", 1000L), ("test_key", 15000L)))
+      // if timestamp equals to expiryTimestampsMs, will not considered expired
+      assert(timerState1.getExpiredTimers(15000L).toSeq === Seq(("test_key", 1000L)))
      assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
 
      timerState1.registerTimer(20L * 1000)
@@ -99,15 +102,67 @@ class TimerSuite extends StateVariableSuiteBase {
      ImplicitGroupingKeyTracker.removeImplicitKey()
 
      ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
-      assert(timerState1.getExpiredTimers().toSet ===
-        Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
+      assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq ===
+        Seq(("test_key1", 1000L), ("test_key1", 2000L), ("test_key2", 15000L)))
+      assert(timerState1.getExpiredTimers(10000L).toSeq ===
+        Seq(("test_key1", 1000L), ("test_key1", 2000L)))
      assert(timerState1.listTimers().toSet === Set(1000L, 2000L))
      ImplicitGroupingKeyTracker.removeImplicitKey()
 
      ImplicitGroupingKeyTracker.setImplicitKey("test_key2")
      assert(timerState2.listTimers().toSet === Set(15000L))
-      assert(timerState2.getExpiredTimers().toSet ===
-        Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
+      assert(timerState2.getExpiredTimers(1500L).toSeq === Seq(("test_key1", 1000L)))
+    }
+  }
+
+  testWithTimeOutMode("Range scan on second index timer key - " +
+    "verify timestamp is sorted for single instance") { timeoutMode =>
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+      val timerState = new TimerStateImpl(store, timeoutMode,
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+      val timerTimerstamps = Seq(931L, 8000L, 452300L, 4200L, 90L, 1L, 2L, 8L, 3L, 35L, 6L, 9L, 5L)
+      // register/put unordered timestamp into rocksDB
+      timerTimerstamps.foreach(timerState.registerTimer)
+      assert(timerState.getExpiredTimers(Long.MaxValue).toSeq.map(_._2) === timerTimerstamps.sorted)
+      assert(timerState.getExpiredTimers(4200L).toSeq.map(_._2) ===
+        timerTimerstamps.sorted.takeWhile(_ < 4200L))
+      assert(timerState.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
+      ImplicitGroupingKeyTracker.removeImplicitKey()
+    }
+  }
+
+  testWithTimeOutMode("test range scan on second index timer key - " +
+    "verify timestamp is sorted for multiple instances") { timeoutMode =>
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
+      val timerState1 = new TimerStateImpl(store, timeoutMode,
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+      val timerTimestamps1 = Seq(64L, 32L, 1024L, 4096L, 0L, 1L)
+      timerTimestamps1.foreach(timerState1.registerTimer)
+
+      val timerState2 = new TimerStateImpl(store, timeoutMode,
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+      val timerTimestamps2 = Seq(931L, 8000L, 452300L, 4200L)
+      timerTimestamps2.foreach(timerState2.registerTimer)
+      ImplicitGroupingKeyTracker.removeImplicitKey()
+
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key3")
+      val timerState3 = new TimerStateImpl(store, timeoutMode,
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+      val timerTimerStamps3 = Seq(1L, 2L, 8L, 3L)
+      timerTimerStamps3.foreach(timerState3.registerTimer)
+      ImplicitGroupingKeyTracker.removeImplicitKey()
+
+      assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq.map(_._2) ===
+        (timerTimestamps1 ++ timerTimestamps2 ++ timerTimerStamps3).sorted)
+      assert(timerState1.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
+      assert(timerState1.getExpiredTimers(8000L).toSeq.map(_._2) ===
+        (timerTimestamps1 ++ timerTimestamps2 ++ timerTimerStamps3).sorted.takeWhile(_ < 8000L))
     }
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
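For readers skimming the change, the behavior that the new `TimerSuite` assertions exercise can be summed up as: because the timestamp-to-key column family is now range-scan encoded, its iterator comes back in ascending timestamp order, so `getExpiredTimers(expiryTimestampMs)` can stop at the first non-expired timer instead of scanning and filtering everything. The following is a minimal, self-contained Scala sketch of that contract; the names (`ExpiredTimerScanSketch`, `timers`, `expiredTimers`) are illustrative stand-ins rather than Spark APIs, and a plain pre-sorted `Seq` stands in for the RocksDB-backed iterator.

```scala
object ExpiredTimerScanSketch {
  // Stand-in for the (groupingKey, expiryTimestampMs) pairs the range-scan-encoded
  // "timestamp to key" column family yields; already in ascending timestamp order.
  val timers: Seq[(String, Long)] =
    Seq(("keyA", 1000L), ("keyB", 4200L), ("keyA", 8000L), ("keyC", 452300L))

  // Mirrors the early-termination idea: keep taking rows while the timestamp is strictly
  // below the threshold, then stop, rather than scanning everything and filtering.
  def expiredTimers(expiryTimestampMs: Long): Iterator[(String, Long)] =
    timers.iterator.takeWhile { case (_, ts) => ts < expiryTimestampMs }

  def main(args: Array[String]): Unit = {
    // Timers strictly below the threshold are expired; a timer at exactly 8000L is not.
    assert(expiredTimers(8000L).toSeq == Seq(("keyA", 1000L), ("keyB", 4200L)))
    assert(expiredTimers(Long.MinValue).isEmpty)          // nothing expires
    assert(expiredTimers(Long.MaxValue).toSeq == timers)  // everything expires
    println("sketch assertions passed")
  }
}
```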
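The early termination above is only safe because of the ordering guarantee itself: `RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1)` asks the state store to encode the first key column (the expiry timestamp) so that byte-wise key order in the store matches numeric order. The sketch below illustrates that property for non-negative longs using big-endian encoding; it is only an illustration of the idea, not the actual Spark encoder.

```scala
import java.nio.ByteBuffer
import java.util.Arrays

object RangeKeyOrderSketch {
  // Encode a non-negative long as 8 big-endian bytes; lexicographic (unsigned) byte order
  // then agrees with numeric order, which is what a range scan over the key space relies on.
  def encode(ts: Long): Array[Byte] =
    ByteBuffer.allocate(java.lang.Long.BYTES).putLong(ts).array()

  def main(args: Array[String]): Unit = {
    val timestamps = Seq(931L, 8000L, 452300L, 4200L, 90L, 1L)
    // Sort by raw byte order, as a key-ordered store would return the encoded keys.
    val byteSorted = timestamps.map(encode).sortWith((a, b) => Arrays.compareUnsigned(a, b) < 0)
    // Decoding the byte-sorted keys recovers the numerically sorted timestamps.
    assert(byteSorted.map(b => ByteBuffer.wrap(b).getLong) == timestamps.sorted)
    println("byte order matches numeric order")
  }
}
```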