This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 839bc9f9c264 [SPARK-46913][SS] Add support for processing/event time based timers with transformWithState operator
839bc9f9c264 is described below

commit 839bc9f9c264671aa75795d714558c61bd6f64b0
Author: Anish Shrigondekar <anish.shrigonde...@databricks.com>
AuthorDate: Thu Mar 14 05:27:28 2024 +0900

    [SPARK-46913][SS] Add support for processing/event time based timers with transformWithState operator

### What changes were proposed in this pull request?

Add support for processing/event time based timers with the `transformWithState` operator.

### Why are the changes needed?

These changes are required to add event-driven, timer-based support for stateful streaming applications built on the arbitrary state API with the `transformWithState` operator.

As part of this change, we introduce a set of functions that users can call within the `StatefulProcessor` logic. Using the `StatefulProcessorHandle`, users can do the following:
- register a timer at a given timestamp
- delete a timer at a given timestamp
- list registered timers

Note that all the above operations are tied to the implicit grouping key.

In terms of the implementation, we make use of additional column families to support the operations mentioned above. For registered timers, we maintain a primary index (as a col family) that keeps the mapping between the grouping key and expiry timestamp. This col family is used to add and delete timers with direct access to the key and also for listing registered timers for a given grouping key using `prefix scan`. We also maintain a secondary index that inverts the ordering of the t [...]

A few additional constraints:
- only registered timers are tracked and occupy storage (locally and remotely)
- col families starting with `_` are reserved and cannot be used as state variables
- timers are checkpointed as before
- users have to provide a `timeoutMode` to the operator. Currently, they can choose to not register timeouts, or to register timeouts that are processing-time based or event-time based. However, this mode has to be declared upfront within the operator arguments.
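To make the new surface concrete, the sketch below shows what a processor built against this API could look like. It is illustrative only: the class name, state variable, and 5-second expiry are made up; the `init`/`handleInputRows` signatures, `ExpiredTimerInfo`, and the handle's `registerTimer` follow this patch, while `getCurrentProcessingTimeInMs()` on `TimerValues` and the `ValueState` accessors (`exists`, `get`, `update`, `clear`) are assumed from the pre-existing API rather than shown in this diff.

```scala
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// Illustrative processor: counts rows per key and emits the final count when a
// 5-second processing-time timer fires.
class CountWithTimerProcessor extends StatefulProcessor[String, String, (String, Long)] {
  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
    countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = {
    if (expiredTimerInfo.isValid()) {
      // Invoked with an empty input iterator when a registered timer fires:
      // emit the final count and clean up the state for this key.
      val result = Iterator.single((key, countState.get()))
      countState.clear()
      result
    } else {
      val count = (if (countState.exists()) countState.get() else 0L) + inputRows.size
      countState.update(count)
      // The timer is tied to the implicit grouping key; assumes TimerValues
      // exposes the current processing time.
      getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)
      Iterator.empty
    }
  }
}
```

Since the timeout mode is declared upfront on the operator, wiring this up would look roughly like `input.groupByKey(identity).transformWithState(new CountWithTimerProcessor(), TimeoutMode.ProcessingTime(), OutputMode.Update())`.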
[info] Total number of tests run: 20
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 20, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #45051 from anishshri-db/task/SPARK-46913.

Authored-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../src/main/resources/error/error-classes.json    |  18 ++
 docs/sql-error-conditions.md                       |  18 ++
 .../apache/spark/sql/streaming/TimeoutMode.java    |  14 +
 .../logical/TransformWithStateTimeoutModes.scala   |   4 +-
 .../spark/sql/streaming/ExpiredTimerInfo.scala}    |  25 +-
 .../spark/sql/streaming/StatefulProcessor.scala    |  10 +-
 .../sql/streaming/StatefulProcessorHandle.scala    |  23 ++
 .../streaming/ExpiredTimerInfoImpl.scala}          |  31 +-
 .../streaming/StatefulProcessorHandleImpl.scala    |  80 ++++-
 .../sql/execution/streaming/TimerStateImpl.scala   | 214 +++++++++++++
 .../streaming/TransformWithStateExec.scala         | 123 +++++++-
 .../state/HDFSBackedStateStoreProvider.scala       |   3 +-
 .../sql/execution/streaming/state/RocksDB.scala    |  21 +-
 .../state/RocksDBStateStoreProvider.scala          |   5 +-
 .../sql/execution/streaming/state/StateStore.scala |  24 +-
 .../streaming/state/StateStoreErrors.scala         |  49 ++-
 .../execution/streaming/state/MapStateSuite.scala  |   8 +-
 .../streaming/state/MemoryStateStore.scala         |   3 +-
 .../execution/streaming/state/RocksDBSuite.scala   |  37 +++
 .../state/StatefulProcessorHandleSuite.scala       | 274 ++++++++++++++++
 .../streaming/state/ValueStateSuite.scala          |  40 ++-
 .../streaming/TransformWithListStateSuite.scala    |  18 +-
 .../sql/streaming/TransformWithMapStateSuite.scala |   7 +-
 .../sql/streaming/TransformWithStateSuite.scala    | 348 ++++++++++++++++++++-
 24 files changed, 1290 insertions(+), 107 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 8272c442ddfa..42c5a107159d 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3491,6 +3491,24 @@
     ],
     "sqlState" : "0A000"
   },
+  "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : {
+    "message" : [
+      "Failed to perform stateful processor operation=<operationType> with invalid handle state=<handleState>."
+    ],
+    "sqlState" : "42802"
+  },
+  "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIMEOUT_MODE" : {
+    "message" : [
+      "Failed to perform stateful processor operation=<operationType> with invalid timeoutMode=<timeoutMode>"
+    ],
+    "sqlState" : "42802"
+  },
+  "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : {
+    "message" : [
+      "Failed to create column family with unsupported starting character and name=<colFamilyName>."
+    ],
+    "sqlState" : "42802"
+  },
   "STATE_STORE_CANNOT_USE_COLUMN_FAMILY_WITH_INVALID_NAME" : {
     "message" : [
       "Failed to perform column family operation=<operationName> with invalid name=<colFamilyName>. Column family name cannot be empty or include leading/trailing spaces or use the reserved keyword=default"
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index dba87bf0136e..b13c8300f4ac 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -2148,6 +2148,24 @@ The SQL config `<sqlConf>` cannot be found. Please verify that the config exists

 Star (*) is not allowed in a select list when GROUP BY an ordinal position is used.
+### STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Failed to perform stateful processor operation=`<operationType>` with invalid handle state=`<handleState>`. + +### STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIMEOUT_MODE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Failed to perform stateful processor operation=`<operationType>` with invalid timeoutMode=`<timeoutMode>` + +### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Failed to create column family with unsupported starting character and name=`<colFamilyName>`. + ### STATE_STORE_CANNOT_USE_COLUMN_FAMILY_WITH_INVALID_NAME [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java b/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java index d62cdba7fdaa..68b8134cda6c 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java @@ -34,4 +34,18 @@ public class TimeoutMode { public static final TimeoutMode NoTimeouts() { return NoTimeouts$.MODULE$; } + + /** + * Stateful processor that only registers processing time based timers + */ + public static final TimeoutMode ProcessingTime() { + return ProcessingTime$.MODULE$; + } + + /** + * Stateful processor that only registers event time based timers + */ + public static final TimeoutMode EventTime() { + return EventTime$.MODULE$; + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TransformWithStateTimeoutModes.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TransformWithStateTimeoutModes.scala index d72678f18571..e420f7821b50 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TransformWithStateTimeoutModes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TransformWithStateTimeoutModes.scala @@ -18,5 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.streaming.TimeoutMode -/** Types of timeouts used in tranformWithState operator */ +/** Types of timeouts used in transformWithState operator */ case object NoTimeouts extends TimeoutMode +case object ProcessingTime extends TimeoutMode +case object EventTime extends TimeoutMode diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala similarity index 62% copy from sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java copy to sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala index d62cdba7fdaa..49dc393f8481 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala @@ -15,23 +15,26 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.streaming; +package org.apache.spark.sql.streaming -import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; -import org.apache.spark.sql.catalyst.plans.logical.*; +import java.io.Serializable + +import org.apache.spark.annotation.{Evolving, Experimental} /** - * Represents the type of timeouts possible for the Dataset operations - * {@code transformWithState}. + * Class used to provide access to expired timer's expiry time. These values + * are only relevant if the ExpiredTimerInfo is valid. */ @Experimental @Evolving -public class TimeoutMode { +private[sql] trait ExpiredTimerInfo extends Serializable { + /** + * Check if provided ExpiredTimerInfo is valid. + */ + def isValid(): Boolean + /** - * Stateful processor that does not register timers + * Get the expired timer's expiry time as milliseconds in epoch time. */ - public static final TimeoutMode NoTimeouts() { - return NoTimeouts$.MODULE$; - } + def getExpiryTimeInMs(): Long } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index 42a9430bf39d..ad9b807ddf5a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -40,8 +40,11 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable { * Function that will be invoked as the first method that allows for users to * initialize all their state variables and perform other init actions before handling data. * @param outputMode - output mode for the stateful processor + * @param timeoutMode - timeout mode for the stateful processor */ - def init(outputMode: OutputMode): Unit + def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode): Unit /** * Function that will allow users to interact with input data rows along with the grouping key @@ -50,12 +53,15 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable { * @param inputRows - iterator of input rows associated with grouping key * @param timerValues - instance of TimerValues that provides access to current processing/event * time if available + * @param expiredTimerInfo - instance of ExpiredTimerInfo that provides access to expired timer + * if applicable * @return - Zero or more output rows */ def handleInputRows( key: K, inputRows: Iterator[I], - timerValues: TimerValues): Iterator[O] + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[O] /** * Function called as the last method that allows for users to perform diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index c26d0d806b86..560188a0ff62 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -70,6 +70,29 @@ private[sql] trait StatefulProcessorHandle extends Serializable { /** Function to return queryInfo for currently running task */ def getQueryInfo(): QueryInfo + /** + * Function to register a processing/event time based timer for given implicit grouping key + * and provided timestamp + * @param expiryTimestampMs - timer expiry timestamp in milliseconds + */ + def registerTimer(expiryTimestampMs: Long): Unit + + /** + * Function to delete a 
processing/event time based timer for given implicit grouping key + * and provided timestamp + * @param expiryTimestampMs - timer expiry timestamp in milliseconds + */ + def deleteTimer(expiryTimestampMs: Long): Unit + + /** + * Function to list all the timers registered for given implicit grouping key + * Note: calling listTimers() within the `handleInputRows` method of the StatefulProcessor + * will return all the unprocessed registered timers, including the one being fired within the + * invocation of `handleInputRows`. + * @return - list of all the registered timers for given implicit grouping key + */ + def listTimers(): Iterator[Long] + /** * Function to delete and purge state variable if defined previously * @param stateName - name of the state variable diff --git a/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ExpiredTimerInfoImpl.scala similarity index 51% copy from sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java copy to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ExpiredTimerInfoImpl.scala index d62cdba7fdaa..8ab05ef852b8 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ExpiredTimerInfoImpl.scala @@ -14,24 +14,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.sql.execution.streaming -package org.apache.spark.sql.streaming; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; -import org.apache.spark.sql.catalyst.plans.logical.*; +import org.apache.spark.sql.streaming.{ExpiredTimerInfo, TimeoutMode} /** - * Represents the type of timeouts possible for the Dataset operations - * {@code transformWithState}. + * Class that provides a concrete implementation that can be used to provide access to expired + * timer's expiry time. These values are only relevant if the ExpiredTimerInfo + * is valid. 
+ * @param isValid - boolean to check if the provided ExpiredTimerInfo is valid + * @param expiryTimeInMsOpt - option to expired timer's expiry time as milliseconds in epoch time */ -@Experimental -@Evolving -public class TimeoutMode { - /** - * Stateful processor that does not register timers - */ - public static final TimeoutMode NoTimeouts() { - return NoTimeouts$.MODULE$; - } +class ExpiredTimerInfoImpl( + isValid: Boolean, + expiryTimeInMsOpt: Option[Long] = None, + timeoutMode: TimeoutMode = TimeoutMode.NoTimeouts()) extends ExpiredTimerInfo { + + override def isValid(): Boolean = isValid + + override def getExpiryTimeInMs(): Long = expiryTimeInMsOpt.getOrElse(-1L) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index fde8d5c3c1e5..9b905ad5235d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -22,8 +22,9 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.state.StateStore -import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, ValueState} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, ValueState} import org.apache.spark.util.Utils /** @@ -45,7 +46,7 @@ object ImplicitGroupingKeyTracker { */ object StatefulProcessorHandleState extends Enumeration { type StatefulProcessorHandleState = Value - val CREATED, INITIALIZED, DATA_PROCESSED, CLOSED = Value + val CREATED, INITIALIZED, DATA_PROCESSED, TIMER_PROCESSED, CLOSED = Value } class QueryInfoImpl( @@ -76,6 +77,7 @@ class StatefulProcessorHandleImpl( store: StateStore, runId: UUID, keyEncoder: ExpressionEncoder[Any], + timeoutMode: TimeoutMode, isStreaming: Boolean = true) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ @@ -114,28 +116,85 @@ class StatefulProcessorHandleImpl( def getHandleState: StatefulProcessorHandleState = currState override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { - verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + - "initialization is complete") + verifyStateVarOperations("get_value_state") val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) resultState } override def getQueryInfo(): QueryInfo = currQueryInfo + private lazy val timerState = new TimerStateImpl(store, timeoutMode, keyEncoder) + + private def verifyStateVarOperations(operationType: String): Unit = { + if (currState != CREATED) { + throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(operationType, + currState.toString) + } + } + + private def verifyTimerOperations(operationType: String): Unit = { + if (timeoutMode == NoTimeouts) { + throw StateStoreErrors.cannotPerformOperationWithInvalidTimeoutMode(operationType, + timeoutMode.toString) + } + + if (currState < INITIALIZED || currState >= TIMER_PROCESSED) { + throw 
StateStoreErrors.cannotPerformOperationWithInvalidHandleState(operationType, + currState.toString) + } + } + + /** + * Function to register a timer for the given expiryTimestampMs + * @param expiryTimestampMs - timestamp in milliseconds for the timer to expire + */ + override def registerTimer(expiryTimestampMs: Long): Unit = { + verifyTimerOperations("register_timer") + timerState.registerTimer(expiryTimestampMs) + } + + /** + * Function to delete a timer for the given expiryTimestampMs + * @param expiryTimestampMs - timestamp in milliseconds for the timer to delete + */ + override def deleteTimer(expiryTimestampMs: Long): Unit = { + verifyTimerOperations("delete_timer") + timerState.deleteTimer(expiryTimestampMs) + } + + /** + * Function to retrieve all registered timers for all grouping keys + * @return - iterator of registered timers for all grouping keys + */ + def getExpiredTimers(): Iterator[(Any, Long)] = { + verifyTimerOperations("get_expired_timers") + timerState.getExpiredTimers() + } + + /** + * Function to list all the registered timers for given implicit key + * Note: calling listTimers() within the `handleInputRows` method of the StatefulProcessor + * will return all the unprocessed registered timers, including the one being fired within the + * invocation of `handleInputRows`. + * @return - iterator of all the registered timers for given implicit key + */ + def listTimers(): Iterator[Long] = { + verifyTimerOperations("list_timers") + timerState.listTimers() + } + /** * Function to delete and purge state variable if defined previously * * @param stateName - name of the state variable */ override def deleteIfExists(stateName: String): Unit = { - verify(currState == CREATED, s"Cannot delete state variable with name=$stateName after " + - "initialization is complete") + verifyStateVarOperations("delete_if_exists") store.removeColFamilyIfExists(stateName) } override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { - verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + - "initialization is complete") + verifyStateVarOperations("get_list_state") val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) resultState } @@ -144,8 +203,7 @@ class StatefulProcessorHandleImpl( stateName: String, userKeyEnc: Encoder[K], valEncoder: Encoder[V]): MapState[K, V] = { - verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + - "initialization is complete") + verifyStateVarOperations("get_map_state") val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) resultState } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala new file mode 100644 index 000000000000..d8b5cb7ef073 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.Serializable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.TimeoutMode +import org.apache.spark.sql.types._ +import org.apache.spark.util.NextIterator + +/** + * Singleton utils class used primarily while interacting with TimerState + */ +object TimerStateUtils { + case class TimestampWithKey( + key: Any, + expiryTimestampMs: Long) extends Serializable + + val PROC_TIMERS_STATE_NAME = "_procTimers" + val EVENT_TIMERS_STATE_NAME = "_eventTimers" + val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp" + val TIMESTAMP_TO_KEY_CF = "_timestampToKey" +} + +/** + * Class that provides the implementation for storing timers + * used within the `transformWithState` operator. + * @param store - state store to be used for storing timer data + * @param timeoutMode - mode of timeout (event time or processing time) + * @param keyExprEnc - encoder for key expression + */ +class TimerStateImpl( + store: StateStore, + timeoutMode: TimeoutMode, + keyExprEnc: ExpressionEncoder[Any]) extends Logging { + + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + + private val schemaForPrefixKey: StructType = new StructType() + .add("key", BinaryType) + + private val schemaForKeyRow: StructType = new StructType() + .add("key", BinaryType) + .add("expiryTimestampMs", LongType, nullable = false) + + private val keySchemaForSecIndex: StructType = new StructType() + .add("expiryTimestampMs", LongType, nullable = false) + .add("key", BinaryType) + + private val schemaForValueRow: StructType = + StructType(Array(StructField("__dummy__", NullType))) + + private val keySerializer = keyExprEnc.createSerializer() + + private val prefixKeyEncoder = UnsafeProjection.create(schemaForPrefixKey) + + private val keyEncoder = UnsafeProjection.create(schemaForKeyRow) + + private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex) + + val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { + TimerStateUtils.PROC_TIMERS_STATE_NAME + } else { + TimerStateUtils.EVENT_TIMERS_STATE_NAME + } + + val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF + store.createColFamilyIfAbsent(keyToTsCFName, + schemaForKeyRow, numColsPrefixKey = 1, + schemaForValueRow, useMultipleValuesPerKey = false, + isInternal = true) + + val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF + store.createColFamilyIfAbsent(tsToKeyCFName, + keySchemaForSecIndex, numColsPrefixKey = 0, + schemaForValueRow, useMultipleValuesPerKey = false, + isInternal = true) + + private def getGroupingKey(cfName: String): Any = { + val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption + if (!keyOption.isDefined) { + throw StateStoreErrors.implicitKeyNotFound(cfName) + } + keyOption.get + } + + private def encodeKey(groupingKey: 
Any, expiryTimestampMs: Long): UnsafeRow = { + val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() + val keyRow = keyEncoder(InternalRow(keyByteArr, expiryTimestampMs)) + keyRow + } + + // We maintain a secondary index that inverts the ordering of the timestamp + // and grouping key + // TODO: use range scan encoder to encode the secondary index key + private def encodeSecIndexKey(groupingKey: Any, expiryTimestampMs: Long): UnsafeRow = { + val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() + val keyRow = secIndexKeyEncoder(InternalRow(expiryTimestampMs, keyByteArr)) + keyRow + } + + /** + * Function to check if the timer for the given key and timestamp is already registered + * @param expiryTimestampMs - expiry timestamp of the timer + * @return - true if the timer is already registered, false otherwise + */ + private def exists(groupingKey: Any, expiryTimestampMs: Long): Boolean = { + getImpl(groupingKey, expiryTimestampMs) != null + } + + private def getImpl(groupingKey: Any, expiryTimestampMs: Long): UnsafeRow = { + store.get(encodeKey(groupingKey, expiryTimestampMs), keyToTsCFName) + } + + /** + * Function to add a new timer for the given key and timestamp + * @param expiryTimestampMs - expiry timestamp of the timer + */ + def registerTimer(expiryTimestampMs: Long): Unit = { + val groupingKey = getGroupingKey(keyToTsCFName) + if (exists(groupingKey, expiryTimestampMs)) { + logWarning(s"Failed to register timer for key=$groupingKey and " + + s"timestamp=$expiryTimestampMs since it already exists") + } else { + store.put(encodeKey(groupingKey, expiryTimestampMs), EMPTY_ROW, keyToTsCFName) + store.put(encodeSecIndexKey(groupingKey, expiryTimestampMs), EMPTY_ROW, tsToKeyCFName) + logDebug(s"Registered timer for key=$groupingKey and timestamp=$expiryTimestampMs") + } + } + + /** + * Function to remove the timer for the given key and timestamp + * @param expiryTimestampMs - expiry timestamp of the timer + */ + def deleteTimer(expiryTimestampMs: Long): Unit = { + val groupingKey = getGroupingKey(keyToTsCFName) + + if (!exists(groupingKey, expiryTimestampMs)) { + logWarning(s"Failed to delete timer for key=$groupingKey and " + + s"timestamp=$expiryTimestampMs since it does not exist") + } else { + store.remove(encodeKey(groupingKey, expiryTimestampMs), keyToTsCFName) + store.remove(encodeSecIndexKey(groupingKey, expiryTimestampMs), tsToKeyCFName) + logDebug(s"Deleted timer for key=$groupingKey and timestamp=$expiryTimestampMs") + } + } + + def listTimers(): Iterator[Long] = { + val keyByteArr = keySerializer.apply(getGroupingKey(keyToTsCFName)) + .asInstanceOf[UnsafeRow].getBytes() + val keyRow = prefixKeyEncoder(InternalRow(keyByteArr)) + val iter = store.prefixScan(keyRow, keyToTsCFName) + iter.map { kv => + val keyRow = kv.key + keyRow.getLong(1) + } + } + + private def getTimerRowFromSecIndex(keyRow: UnsafeRow): (Any, Long) = { + // Decode the key object from the UnsafeRow + val keyBytes = keyRow.getBinary(1) + val retUnsafeRow = new UnsafeRow(1) + retUnsafeRow.pointTo(keyBytes, keyBytes.length) + val keyObj = keyExprEnc.resolveAndBind() + .createDeserializer().apply(retUnsafeRow).asInstanceOf[Any] + + val expiryTimestampMs = keyRow.getLong(0) + (keyObj, expiryTimestampMs) + } + + /** + * Function to get all the registered timers for all grouping keys + * @return - iterator of all the registered timers for all grouping keys + */ + def getExpiredTimers(): Iterator[(Any, Long)] = { + val iter = store.iterator(tsToKeyCFName) + 
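The two-index layout built above can be pictured with a small in-memory model. The sketch below is purely illustrative (the real implementation encodes grouping keys as `UnsafeRow` bytes in two RocksDB column families) and, like this patch, which leaves a range-scan encoder as a TODO, it scans the full secondary index and filters by timestamp rather than relying on ordering:

```scala
import scala.collection.immutable.SortedSet

// Toy model of the timer indexes: the primary index is keyed (groupingKey, ts)
// so listing a key's timers is a prefix scan; the secondary index is keyed
// (ts, groupingKey) so expired timers can be found without knowing the keys.
case class TimerIndexes(
    primary: SortedSet[(String, Long)] = SortedSet.empty[(String, Long)],
    secondary: SortedSet[(Long, String)] = SortedSet.empty[(Long, String)]) {

  // registerTimer(): write to both indexes.
  def register(key: String, ts: Long): TimerIndexes =
    copy(primary + ((key, ts)), secondary + ((ts, key)))

  // deleteTimer(): remove from both indexes.
  def delete(key: String, ts: Long): TimerIndexes =
    copy(primary - ((key, ts)), secondary - ((ts, key)))

  // listTimers(): all timers for one grouping key (prefix scan on the key).
  def listTimers(key: String): Iterator[Long] =
    primary.iterator.filter(_._1 == key).map(_._2)

  // getExpiredTimers() plus the caller's expiry check: scan the secondary
  // index and filter by the processing-time batch timestamp or watermark.
  def expiredBefore(bound: Long): Iterator[(String, Long)] =
    secondary.iterator.filter(_._1 < bound).map { case (ts, k) => (k, ts) }
}
```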
+ new NextIterator[(Any, Long)] { + override protected def getNext(): (Any, Long) = { + if (iter.hasNext) { + val rowPair = iter.next() + val keyRow = rowPair.key + val result = getTimerRowFromSecIndex(keyRow) + result + } else { + finished = true + null.asInstanceOf[(Any, Long)] + } + } + + override protected def close(): Unit = { } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 117bc722f09e..500da5492f88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.streaming.state._ @@ -69,8 +70,20 @@ case class TransformWithStateExec( override def shortName: String = "transformWithStateExec" - // TODO: update this to run no-data batches when timer support is added - override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = false + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { + timeoutMode match { + // TODO: check if we can return true only if actual timers are registered + case ProcessingTime => + true + + case EventTime => + eventTimeWatermarkForEviction.isDefined && + newInputWatermark > eventTimeWatermarkForEviction.get + + case _ => + false + } + } override protected def withNewChildInternal( newChild: SparkPlan): TransformWithStateExec = copy(child = newChild) @@ -103,8 +116,11 @@ case class TransformWithStateExec( val keyObj = getKeyObj(keyRow) // convert key to objects ImplicitGroupingKeyTracker.setImplicitKey(keyObj) val valueObjIter = valueRowIter.map(getValueObj.apply) - val mappedIterator = statefulProcessor.handleInputRows(keyObj, valueObjIter, - new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForLateEvents)).map { obj => + val mappedIterator = statefulProcessor.handleInputRows( + keyObj, + valueObjIter, + new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForEviction), + new ExpiredTimerInfoImpl(isValid = false)).map { obj => getOutputRow(obj) } ImplicitGroupingKeyTracker.removeImplicitKey() @@ -119,6 +135,56 @@ case class TransformWithStateExec( } } + private def handleTimerRows( + keyObj: Any, + expiryTimestampMs: Long, + processorHandle: StatefulProcessorHandleImpl): Iterator[InternalRow] = { + val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) + ImplicitGroupingKeyTracker.setImplicitKey(keyObj) + val mappedIterator = statefulProcessor.handleInputRows( + keyObj, + Iterator.empty, + new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForEviction), + new ExpiredTimerInfoImpl(isValid = true, Some(expiryTimestampMs))).map { obj => + getOutputRow(obj) + } + processorHandle.deleteTimer(expiryTimestampMs) + ImplicitGroupingKeyTracker.removeImplicitKey() + mappedIterator + } + + private def processTimers( + timeoutMode: TimeoutMode, + processorHandle: StatefulProcessorHandleImpl): Iterator[InternalRow] = { + timeoutMode match { + case 
ProcessingTime => + assert(batchTimestampMs.isDefined) + val batchTimestamp = batchTimestampMs.get + val procTimeIter = processorHandle.getExpiredTimers() + procTimeIter.flatMap { case (keyObj, expiryTimestampMs) => + if (expiryTimestampMs < batchTimestamp) { + handleTimerRows(keyObj, expiryTimestampMs, processorHandle) + } else { + Iterator.empty + } + } + + case EventTime => + assert(eventTimeWatermarkForEviction.isDefined) + val watermark = eventTimeWatermarkForEviction.get + val eventTimeIter = processorHandle.getExpiredTimers() + eventTimeIter.flatMap { case (keyObj, expiryTimestampMs) => + if (expiryTimestampMs < watermark) { + handleTimerRows(keyObj, expiryTimestampMs, processorHandle) + } else { + Iterator.empty + } + } + + case _ => Iterator.empty + } + } + private def processDataWithPartition( iter: Iterator[InternalRow], store: StateStore, @@ -126,9 +192,11 @@ case class TransformWithStateExec( CompletionIterator[InternalRow, Iterator[InternalRow]] = { val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") val commitTimeMs = longMetric("commitTimeMs") + val timeoutLatencyMs = longMetric("allRemovalsTimeMs") val currentTimeNs = System.nanoTime val updatesStartTimeNs = currentTimeNs + var timeoutProcessingStartTimeNs = currentTimeNs // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForDataForLateEvents match { @@ -138,8 +206,33 @@ case class TransformWithStateExec( iter } - val outputIterator = processNewData(filteredIter) - processorHandle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) + val newDataProcessorIter = + CompletionIterator[InternalRow, Iterator[InternalRow]]( + processNewData(filteredIter), { + // Once the input is processed, mark the start time for timeout processing to measure + // it separately from the overall processing time. + timeoutProcessingStartTimeNs = System.nanoTime + processorHandle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) + }) + + // Late-bind the timeout processing iterator so it is created *after* the input is + // processed (the input iterator is exhausted) and the state updates are written into the + // state store. Otherwise the iterator may not see the updates (e.g. with RocksDB state store). + val timeoutProcessorIter = new Iterator[InternalRow] { + private lazy val itr = getIterator() + override def hasNext = itr.hasNext + override def next() = itr.next() + private def getIterator(): Iterator[InternalRow] = + CompletionIterator[InternalRow, Iterator[InternalRow]]( + processTimers(timeoutMode, processorHandle), { + // Note: `timeoutLatencyMs` also includes the time the parent operator took for + // processing output returned through iterator. 
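The late-binding trick described in the comment above can be distilled into a small standalone helper. This sketch is illustrative, not the PR's code: the second phase's iterator must not be constructed until the first phase's iterator is exhausted, otherwise it may not observe state written during the first phase.

```scala
// Chain two processing phases lazily: `secondPhase` is only invoked once
// `first` has been fully consumed, so it sees all of the first phase's writes.
def chainPhases[T](first: Iterator[T], secondPhase: () => Iterator[T]): Iterator[T] = {
  val lateBound = new Iterator[T] {
    private lazy val underlying = secondPhase() // built on first access
    override def hasNext: Boolean = underlying.hasNext
    override def next(): T = underlying.next()
  }
  first ++ lateBound
}
```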
+ timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs) + processorHandle.setHandleState(StatefulProcessorHandleState.TIMER_PROCESSED) + }) + } + + val outputIterator = newDataProcessorIter ++ timeoutProcessorIter // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, { @@ -164,6 +257,20 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver + timeoutMode match { + case ProcessingTime => + if (batchTimestampMs.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case EventTime => + if (eventTimeWatermarkForEviction.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case _ => + } + if (isStreaming) { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, @@ -227,10 +334,10 @@ case class TransformWithStateExec( private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( - store, getStateInfo.queryRunId, keyEncoder, isStreaming) + store, getStateInfo.queryRunId, keyEncoder, timeoutMode, isStreaming) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) - statefulProcessor.init(outputMode) + statefulProcessor.init(outputMode, timeoutMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) processDataWithPartition(singleIterator, store, processorHandle) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index edb95615d588..42d10f6c1bd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -125,7 +125,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with keySchema: StructType, numColsPrefixKey: Int, valueSchema: StructType, - useMultipleValuesPerKey: Boolean = false): Unit = { + useMultipleValuesPerKey: Boolean = false, + isInternal: Boolean = false): Unit = { throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 4437cc5583d4..950baba9031b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -221,7 +221,7 @@ class RocksDB( changelogReader = fileManager.getChangelogReader(v, useColumnFamilies) changelogReader.foreach { case (recordType, key, value, colFamilyName) => if (useColumnFamilies && !checkColFamilyExists(colFamilyName)) { - createColFamilyIfAbsent(colFamilyName) + createColFamilyIfAbsent(colFamilyName, checkInternalColumnFamilies(colFamilyName)) } recordType match { @@ -290,7 +290,8 @@ class RocksDB( */ private def verifyColFamilyCreationOrDeletion( operationName: String, - 
colFamilyName: String): Unit = { + colFamilyName: String, + isInternal: Boolean = false): Unit = { // if the state store instance does not support multiple column families, throw an exception if (!useColumnFamilies) { throw StateStoreErrors.unsupportedOperationException(operationName, @@ -304,13 +305,25 @@ class RocksDB( || colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME) { throw StateStoreErrors.cannotUseColumnFamilyWithInvalidName(operationName, colFamilyName) } + + // if the column family is not internal and uses reserved characters, throw an exception + if (!isInternal && colFamilyName.charAt(0) == '_') { + throw StateStoreErrors.cannotCreateColumnFamilyWithReservedChars(colFamilyName) + } } + /** + * Check whether the column family name is for internal column families. + * @param cfName - column family name + * @return - true if the column family is for internal use, false otherwise + */ + private def checkInternalColumnFamilies(cfName: String): Boolean = cfName.charAt(0) == '_' + /** * Create RocksDB column family, if not created already */ - def createColFamilyIfAbsent(colFamilyName: String): Unit = { - verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName) + def createColFamilyIfAbsent(colFamilyName: String, isInternal: Boolean = false): Unit = { + verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) if (!checkColFamilyExists(colFamilyName)) { assert(db != null) val descriptor = new ColumnFamilyDescriptor(colFamilyName.getBytes, columnFamilyOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 721d8aa03079..89471f6af535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -53,11 +53,12 @@ private[sql] class RocksDBStateStoreProvider keySchema: StructType, numColsPrefixKey: Int, valueSchema: StructType, - useMultipleValuesPerKey: Boolean = false): Unit = { + useMultipleValuesPerKey: Boolean = false, + isInternal: Boolean = false): Unit = { verify(colFamilyName != StateStore.DEFAULT_COL_FAMILY_NAME, s"Failed to create column family with reserved_name=$colFamilyName") verify(useColumnFamilies, "Column families are not supported in this store") - rocksDB.createColFamilyIfAbsent(colFamilyName) + rocksDB.createColFamilyIfAbsent(colFamilyName, isInternal) keyValueEncoderMap.putIfAbsent(colFamilyName, (RocksDBStateEncoder.getKeyEncoder(keySchema, numColsPrefixKey), RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index e2eb0c0728d8..9247c9fe41b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -64,8 +64,9 @@ trait ReadStateStore { * Get the current value of a non-null key. * @return a non-null row if the key exists in the store, otherwise null. 
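Condensing the validation added to RocksDB above: user-created column families may not begin with `_`, which keeps user state variables from colliding with internal families such as `_procTimers_keyToTimestamp`. A hypothetical standalone restatement of the rule, with an illustrative function name and generic exception:

```scala
// Internal callers may use the reserved '_' prefix; user-facing state
// variables may not. Names also cannot be empty, padded, or "default".
def verifyColFamilyName(name: String, isInternal: Boolean): Unit = {
  require(name.nonEmpty && name.trim == name && name != "default",
    s"invalid column family name=$name")
  if (!isInternal && name.startsWith("_")) {
    throw new UnsupportedOperationException(
      s"cannot create column family $name: the '_' prefix is reserved for internal use")
  }
}
```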
*/ - def get(key: UnsafeRow, - colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow + def get( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow /** * Provides an iterator containing all values of a non-null key. If key does not exist, @@ -89,8 +90,9 @@ trait ReadStateStore { * It is expected to throw exception if Spark calls this method without setting numColsPrefixKey * to the greater than 0. */ - def prefixScan(prefixKey: UnsafeRow, - colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] + def prefixScan( + prefixKey: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] /** Return an iterator containing all the key-value pairs in the StateStore. */ def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] @@ -128,20 +130,24 @@ trait StateStore extends ReadStateStore { keySchema: StructType, numColsPrefixKey: Int, valueSchema: StructType, - useMultipleValuesPerKey: Boolean = false): Unit + useMultipleValuesPerKey: Boolean = false, + isInternal: Boolean = false): Unit /** * Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows * in the params can be reused, and must make copies of the data as needed for persistence. */ - def put(key: UnsafeRow, value: UnsafeRow, - colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit + def put( + key: UnsafeRow, + value: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit /** * Remove a single non-null key. */ - def remove(key: UnsafeRow, - colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit + def remove( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit /** * Merges the provided value with existing values of a non-null key. 
If a existing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 8a0276557f8f..7d1ec7f03237 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -32,6 +32,13 @@ object StateStoreErrors { ) } + def missingTimeoutValues(timeoutMode: String): SparkException = { + SparkException.internalError( + msg = s"Failed to find timeout values for timeoutMode=$timeoutMode", + category = "TWS" + ) + } + def unsupportedOperationOnMissingColumnFamily(operationName: String, colFamilyName: String): StateStoreUnsupportedOperationOnMissingColumnFamily = { new StateStoreUnsupportedOperationOnMissingColumnFamily(operationName, colFamilyName) @@ -68,6 +75,23 @@ object StateStoreErrors { errorClass = "ILLEGAL_STATE_STORE_VALUE.EMPTY_LIST_VALUE", messageParameters = Map("stateName" -> stateName)) } + + def cannotCreateColumnFamilyWithReservedChars(colFamilyName: String): + StateStoreCannotCreateColumnFamilyWithReservedChars = { + new StateStoreCannotCreateColumnFamilyWithReservedChars(colFamilyName) + } + + def cannotPerformOperationWithInvalidTimeoutMode( + operationType: String, + timeoutMode: String): StatefulProcessorCannotPerformOperationWithInvalidTimeoutMode = { + new StatefulProcessorCannotPerformOperationWithInvalidTimeoutMode(operationType, timeoutMode) + } + + def cannotPerformOperationWithInvalidHandleState( + operationType: String, + handleState: String): StatefulProcessorCannotPerformOperationWithInvalidHandleState = { + new StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType, handleState) + } } class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) @@ -85,10 +109,33 @@ class StateStoreCannotUseColumnFamilyWithInvalidName(operationName: String, colF errorClass = "STATE_STORE_CANNOT_USE_COLUMN_FAMILY_WITH_INVALID_NAME", messageParameters = Map("operationName" -> operationName, "colFamilyName" -> colFamilyName)) +class StateStoreCannotCreateColumnFamilyWithReservedChars(colFamilyName: String) + extends SparkUnsupportedOperationException( + errorClass = "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS", + messageParameters = Map("colFamilyName" -> colFamilyName) + ) + class StateStoreUnsupportedOperationException(operationType: String, entity: String) extends SparkUnsupportedOperationException( errorClass = "STATE_STORE_UNSUPPORTED_OPERATION", - messageParameters = Map("operationType" -> operationType, "entity" -> entity)) + messageParameters = Map("operationType" -> operationType, "entity" -> entity) + ) + +class StatefulProcessorCannotPerformOperationWithInvalidTimeoutMode( + operationType: String, + timeoutMode: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIMEOUT_MODE", + messageParameters = Map("operationType" -> operationType, "timeoutMode" -> timeoutMode) + ) + +class StatefulProcessorCannotPerformOperationWithInvalidHandleState( + operationType: String, + handleState: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", + messageParameters = Map("operationType" -> operationType, "handleState" -> handleState) + ) class 
StateStoreUnsupportedOperationOnMissingColumnFamily( operationType: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index d3ed9b5824a1..f7aed2045793 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, MapState, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, ValueState} import org.apache.spark.sql.types.{BinaryType, StructType} /** @@ -39,7 +39,7 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) val testState: MapState[String, Double] = handle.getMapState[String, Double]("testState", Encoders.STRING, Encoders.scalaDouble) @@ -73,7 +73,7 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) val testState1: MapState[Long, Double] = handle.getMapState[Long, Double]("testState1", Encoders.scalaLong, Encoders.scalaDouble) @@ -112,7 +112,7 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) val mapTestState1: MapState[String, Int] = handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, Encoders.scalaInt) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index fa5889891b93..91ffb7a66adc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -35,7 +35,8 @@ class MemoryStateStore extends StateStore() { keySchema: StructType, numColsPrefixKey: Int, valueSchema: StructType, - useMultipleValuesPerKey: Boolean = false): Unit = { + useMultipleValuesPerKey: Boolean = false, + isInternal: Boolean = false): Unit = { throw StateStoreErrors.multipleColumnFamiliesNotSupported("MemoryStateStoreProvider") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 
a7d4ab362340..87bf172c579a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -573,6 +573,43 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } } + testWithColumnFamilies(s"RocksDB: column family creation with reserved chars", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + val remoteDir = Utils.createTempDir().toString + new File(remoteDir).delete() // to make sure that the directory gets created + + val conf = RocksDBConf().copy() + withDB(remoteDir, conf = conf, useColumnFamilies = colFamiliesEnabled) { db => + Seq("_internal", "_test", "_test123", "__12345").foreach { colFamilyName => + val ex = intercept[SparkUnsupportedOperationException] { + db.createColFamilyIfAbsent(colFamilyName) + } + + if (!colFamiliesEnabled) { + checkError( + ex, + errorClass = "STATE_STORE_UNSUPPORTED_OPERATION", + parameters = Map( + "operationType" -> "create_col_family", + "entity" -> "multiple column families disabled in RocksDBStateStoreProvider" + ), + matchPVals = true + ) + } else { + checkError( + ex, + errorClass = "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS", + parameters = Map( + "colFamilyName" -> colFamilyName + ), + matchPVals = true + ) + } + } + } + } + + private def verifyStoreOperationUnsupported( operationName: String, colFamiliesEnabled: Boolean, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala new file mode 100644 index 000000000000..5d9a9cbcaae0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.UUID + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.TimeoutMode +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +/** + * Class that adds tests to verify operations based on stateful processor handle + * used primarily in queries based on the `transformWithState` operator. + */ +class StatefulProcessorHandleSuite extends SharedSparkSession + with BeforeAndAfter { + + before { + StateStore.stop() + require(!StateStore.isMaintenanceRunning) + } + + after { + StateStore.stop() + require(!StateStore.isMaintenanceRunning) + } + + import StateStoreTestsHelper._ + + val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) + + val schemaForValueRow: StructType = new StructType().add("value", BinaryType) + + private def keyExprEncoder: ExpressionEncoder[Any] = + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]] + + private def newStoreProviderWithHandle(useColumnFamilies: Boolean): + RocksDBStateStoreProvider = { + newStoreProviderWithHandle(StateStoreId(newDir(), Random.nextInt(), 0), + numColsPrefixKey = 0, + useColumnFamilies = useColumnFamilies) + } + + private def newStoreProviderWithHandle( + storeId: StateStoreId, + numColsPrefixKey: Int, + sqlConf: Option[SQLConf] = None, + conf: Configuration = new Configuration, + useColumnFamilies: Boolean = false): RocksDBStateStoreProvider = { + val provider = new RocksDBStateStoreProvider() + provider.init( + storeId, schemaForKeyRow, schemaForValueRow, numColsPrefixKey = numColsPrefixKey, + useColumnFamilies, + new StateStoreConf(sqlConf.getOrElse(SQLConf.get)), conf) + provider + } + + private def tryWithProviderResource[T]( + provider: StateStoreProvider)(f: StateStoreProvider => T): T = { + try { + f(provider) + } finally { + provider.close() + } + } + + private def getTimeoutMode(timeoutMode: String): TimeoutMode = { + timeoutMode match { + case "NoTimeouts" => TimeoutMode.NoTimeouts() + case "ProcessingTime" => TimeoutMode.ProcessingTime() + case "EventTime" => TimeoutMode.EventTime() + case _ => throw new IllegalArgumentException(s"Invalid timeoutMode=$timeoutMode") + } + } + + Seq("NoTimeouts", "ProcessingTime", "EventTime").foreach { timeoutMode => + test(s"value state creation with timeoutMode=$timeoutMode should succeed") { + tryWithProviderResource(newStoreProviderWithHandle(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, + UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + assert(handle.getHandleState === StatefulProcessorHandleState.CREATED) + handle.getValueState[Long]("testState", Encoders.scalaLong) + } + } + } + + private def verifyInvalidOperation( + handle: StatefulProcessorHandleImpl, + handleState: StatefulProcessorHandleState.Value, + operationType: String)(fn: StatefulProcessorHandleImpl => Unit): Unit = { + handle.setHandleState(handleState) + assert(handle.getHandleState === handleState) + val ex = intercept[SparkUnsupportedOperationException] { + fn(handle) + } + 
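These tests pin down the handle lifecycle. As a compact restatement of the rules they exercise (illustrative only, mirroring `verifyStateVarOperations` and `verifyTimerOperations` from the implementation):

```scala
import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState
import org.apache.spark.sql.streaming.TimeoutMode

// Handle lifecycle: CREATED -> INITIALIZED -> DATA_PROCESSED ->
// TIMER_PROCESSED -> CLOSED.
object HandleRules {
  import StatefulProcessorHandleState._

  // State variables may only be created before initialization completes.
  def canCreateStateVariable(s: StatefulProcessorHandleState.Value): Boolean =
    s == CREATED

  // Timer operations require a timer-capable mode and a handle that is past
  // init but has not yet finished processing timers.
  def canUseTimers(s: StatefulProcessorHandleState.Value, mode: TimeoutMode): Boolean =
    mode != TimeoutMode.NoTimeouts() && s >= INITIALIZED && s < TIMER_PROCESSED
}
```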
+  private def verifyInvalidOperation(
+      handle: StatefulProcessorHandleImpl,
+      handleState: StatefulProcessorHandleState.Value,
+      operationType: String)(fn: StatefulProcessorHandleImpl => Unit): Unit = {
+    handle.setHandleState(handleState)
+    assert(handle.getHandleState === handleState)
+    val ex = intercept[SparkUnsupportedOperationException] {
+      fn(handle)
+    }
+
+    checkError(
+      ex,
+      errorClass = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE",
+      parameters = Map(
+        "operationType" -> operationType,
+        "handleState" -> handleState.toString
+      ),
+      matchPVals = true
+    )
+  }
+
+  private def createValueStateInstance(handle: StatefulProcessorHandleImpl): Unit = {
+    handle.getValueState[Long]("testState", Encoders.scalaLong)
+  }
+
+  private def registerTimer(handle: StatefulProcessorHandleImpl): Unit = {
+    handle.registerTimer(1000L)
+  }
+
+  Seq("NoTimeouts", "ProcessingTime", "EventTime").foreach { timeoutMode =>
+    test(s"value state creation with timeoutMode=$timeoutMode " +
+      "and invalid state should fail") {
+      tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+        val store = provider.getStore(0)
+        val handle = new StatefulProcessorHandleImpl(store,
+          UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
+
+        Seq(StatefulProcessorHandleState.INITIALIZED,
+          StatefulProcessorHandleState.DATA_PROCESSED,
+          StatefulProcessorHandleState.TIMER_PROCESSED,
+          StatefulProcessorHandleState.CLOSED).foreach { state =>
+          verifyInvalidOperation(handle, state, "get_value_state") { handle =>
+            createValueStateInstance(handle)
+          }
+        }
+      }
+    }
+  }
+
+  test("registering processing/event time timeouts with NoTimeout mode should fail") {
+    tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store,
+        UUID.randomUUID(), keyExprEncoder, TimeoutMode.NoTimeouts())
+      val ex = intercept[SparkUnsupportedOperationException] {
+        handle.registerTimer(10000L)
+      }
+
+      checkError(
+        ex,
+        errorClass = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIMEOUT_MODE",
+        parameters = Map(
+          "operationType" -> "register_timer",
+          "timeoutMode" -> TimeoutMode.NoTimeouts().toString
+        ),
+        matchPVals = true
+      )
+
+      val ex2 = intercept[SparkUnsupportedOperationException] {
+        handle.deleteTimer(10000L)
+      }
+
+      checkError(
+        ex2,
+        errorClass = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIMEOUT_MODE",
+        parameters = Map(
+          "operationType" -> "delete_timer",
+          "timeoutMode" -> TimeoutMode.NoTimeouts().toString
+        ),
+        matchPVals = true
+      )
+    }
+  }
+
+  Seq("ProcessingTime", "EventTime").foreach { timeoutMode =>
+    test(s"registering timeouts with timeoutMode=$timeoutMode should succeed") {
+      tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+        val store = provider.getStore(0)
+        val handle = new StatefulProcessorHandleImpl(store,
+          UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
+        handle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+        assert(handle.getHandleState === StatefulProcessorHandleState.INITIALIZED)
+
+        ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+        assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isDefined)
+
+        handle.registerTimer(10000L)
+        handle.deleteTimer(10000L)
+
+        ImplicitGroupingKeyTracker.removeImplicitKey()
+        assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
+      }
+    }
+  }
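Taken together, the tests above pin down the timer API surface on the handle: `registerTimer`, `deleteTimer`, and `listTimers`, all scoped to the implicit grouping key and only legal in timer-capable timeout modes and handle states. A minimal sketch of calling that surface from inside a processor (the 5-second delay is illustrative):

```
// Inside handleInputRows(), with the grouping key set implicitly by the operator:
val expiryMs = timerValues.getCurrentProcessingTimeInMs() + 5000
getHandle.registerTimer(expiryMs)             // schedule a timer for this key
getHandle.listTimers().toSeq.foreach { ts =>  // materialize, then walk this key's timers
  if (ts < expiryMs) {
    getHandle.deleteTimer(ts)                 // drop any older ones
  }
}
```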
+  Seq("ProcessingTime", "EventTime").foreach { timeoutMode =>
+    test(s"verify listing of registered timers with timeoutMode=$timeoutMode") {
+      tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+        val store = provider.getStore(0)
+        val handle = new StatefulProcessorHandleImpl(store,
+          UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
+        handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED)
+        assert(handle.getHandleState === StatefulProcessorHandleState.DATA_PROCESSED)
+
+        ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
+        assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isDefined)
+
+        // Register some timer timestamps in arbitrary (unsorted) order
+        val timerTimestamps1 = Seq(931L, 8000L, 452300L, 4200L, 90L,
+          1L, 2L, 8L, 3L, 35L, 6L, 9L, 5L)
+        timerTimestamps1.foreach { timestamp =>
+          handle.registerTimer(timestamp)
+        }
+
+        val timers1 = handle.listTimers()
+        assert(timers1.toSeq.sorted === timerTimestamps1.sorted)
+        ImplicitGroupingKeyTracker.removeImplicitKey()
+
+        ImplicitGroupingKeyTracker.setImplicitKey("test_key2")
+        assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isDefined)
+
+        // Register some timer timestamps for a second grouping key
+        val timerTimestamps2 = Seq(12000L, 14500L, 16000L)
+        timerTimestamps2.foreach { timestamp =>
+          handle.registerTimer(timestamp)
+        }
+
+        val timers2 = handle.listTimers()
+        assert(timers2.toSeq.sorted === timerTimestamps2.sorted)
+        ImplicitGroupingKeyTracker.removeImplicitKey()
+        assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
+      }
+    }
+  }
+
+  Seq("ProcessingTime", "EventTime").foreach { timeoutMode =>
+    test(s"registering timeouts with timeoutMode=$timeoutMode and invalid state should fail") {
+      tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+        val store = provider.getStore(0)
+        val handle = new StatefulProcessorHandleImpl(store,
+          UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
+
+        Seq(StatefulProcessorHandleState.CREATED,
+          StatefulProcessorHandleState.TIMER_PROCESSED,
+          StatefulProcessorHandleState.CLOSED).foreach { state =>
+          verifyInvalidOperation(handle, state, "register_timer") { handle =>
+            registerTimer(handle)
+          }
+        }
+      }
+    }
+  }
+}
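Note that the listing tests compare sorted sequences, so `listTimers()` is not assumed to return timers in any particular order. A processor that needs only the earliest pending expiry can sort on its side; a minimal sketch:

```
// Earliest registered timer for the current key, if any (sketch).
val nextExpiry: Option[Long] = {
  val timers = getHandle.listTimers().toSeq
  if (timers.isEmpty) None else Some(timers.min)
}
```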
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 0bf14037a2ea..e423f9e7385a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -24,12 +24,12 @@ import scala.util.Random
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.ValueState
+import org.apache.spark.sql.streaming.{TimeoutMode, ValueState}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 
@@ -48,7 +48,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts())
 
       val stateName = "testState"
       val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong)
@@ -92,7 +92,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts())
 
       val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
@@ -118,7 +118,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts())
 
       val testState1: ValueState[Long] = handle.getValueState[Long](
         "testState1", Encoders.scalaLong)
@@ -159,6 +159,28 @@ class ValueStateSuite extends StateVariableSuiteBase {
     }
   }
 
+  test("Value state operations for unsupported type name should fail") {
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store,
+        UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TimeoutMode.NoTimeouts())
+
+      val cfName = "_testState"
+      val ex = intercept[SparkUnsupportedOperationException] {
+        handle.getValueState[Long](cfName, Encoders.scalaLong)
+      }
+      checkError(
+        ex,
+        errorClass = "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS",
+        parameters = Map(
+          "colFamilyName" -> cfName
+        ),
+        matchPVals = true
+      )
+    }
+  }
+
   test("colFamily with HDFSBackedStateStoreProvider should fail") {
     val storeId = StateStoreId(newDir(), Random.nextInt(), 0)
     val provider = new HDFSBackedStateStoreProvider()
@@ -182,7 +204,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts())
 
       val testState: ValueState[Double] = handle.getValueState[Double]("testState",
         Encoders.scalaDouble)
@@ -208,7 +230,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts())
 
       val testState: ValueState[Long] = handle.getValueState[Long]("testState",
         Encoders.scalaLong)
@@ -234,7 +256,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts())
 
       val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState",
         Encoders.product[TestClass])
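The new negative test above exercises the reserved-name rule: state variable names beginning with `_` are rejected, since the operator reserves that prefix for internal column families such as the timer indexes. From the user's point of view (names here are hypothetical):

```
// Valid: ordinary state variable name.
val count = getHandle.getValueState[Long]("countState", Encoders.scalaLong)

// Invalid: a leading underscore collides with reserved internal column families
// and throws SparkUnsupportedOperationException with error class
// STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS.
// val bad = getHandle.getValueState[Long]("_countState", Encoders.scalaLong)
```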
@@ -260,7 +282,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts())
 
       val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState",
         Encoders.bean(classOf[POJOTestClass]))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
index 9572f7006f37..95ab34d40131 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -30,14 +30,17 @@ class TestListStateProcessor
 
   @transient var _listState: ListState[String] = _
 
-  override def init(outputMode: OutputMode): Unit = {
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode): Unit = {
     _listState = getHandle.getListState("testListState", Encoders.STRING)
   }
 
   override def handleInputRows(
       key: String,
       rows: Iterator[InputRow],
-      timerValues: TimerValues): Iterator[(String, String)] = {
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
 
     var output = List[(String, String)]()
@@ -76,8 +79,6 @@ class TestListStateProcessor
 
     output.iterator
   }
-
-  override def close(): Unit = {}
 }
 
 class ToggleSaveAndEmitProcessor
@@ -86,7 +87,9 @@ class ToggleSaveAndEmitProcessor
   @transient var _listState: ListState[String] = _
   @transient var _valueState: ValueState[Boolean] = _
 
-  override def init(outputMode: OutputMode): Unit = {
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode): Unit = {
     _listState = getHandle.getListState("testListState", Encoders.STRING)
     _valueState = getHandle.getValueState("testValueState", Encoders.scalaBoolean)
   }
@@ -94,7 +97,8 @@ class ToggleSaveAndEmitProcessor
   override def handleInputRows(
       key: String,
       rows: Iterator[String],
-      timerValues: TimerValues): Iterator[String] = {
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
     val valueStateOption = _valueState.getOption()
 
     if (valueStateOption.isEmpty || !valueStateOption.get) {
@@ -121,8 +125,6 @@ class ToggleSaveAndEmitProcessor
       }
     }
   }
-
-  override def close(): Unit = {}
 }
 
 class TransformWithListStateSuite extends StreamTest
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
index a0576497f399..d7c5ce3815b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
@@ -30,14 +30,17 @@ class TestMapStateProcessor
 
   @transient var _mapState: MapState[String, String] = _
 
-  override def init(outputMode: OutputMode): Unit = {
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode): Unit = {
     _mapState = getHandle.getMapState("sessionState", Encoders.STRING, Encoders.STRING)
   }
 
   override def handleInputRows(
       key: String,
       inputRows: Iterator[InputMapRow],
-      timerValues: TimerValues): Iterator[(String, String, String)] = {
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = {
 
     var output = List[(String, String, String)]()
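Across all of these processor test classes the interface change is uniform: `init` now also receives the operator's `TimeoutMode`, `handleInputRows` gains an `ExpiredTimerInfo` so one method serves both the data path and the timer-expiry path, and the no-op `close()` overrides disappear (presumably a default is now inherited). The new shape, reduced to a skeleton (class name hypothetical, bodies elided):

```
class SketchProcessor extends StatefulProcessor[String, String, (String, String)] {
  override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
    // create state variables via getHandle here
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
    if (expiredTimerInfo.isValid()) {
      // timer-expiry path: expiredTimerInfo.getExpiryTimeInMs() says which timer fired
      Iterator.empty
    } else {
      // data path: process inputRows as before
      Iterator.empty
    }
  }
}
```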
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index cc8c64c94c02..0fd2ef055ffc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.sql.streaming
 
-import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StateStoreMultipleColumnFamiliesNotSupportedException}
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateStoreMultipleColumnFamiliesNotSupportedException}
+import org.apache.spark.sql.functions.timestamp_seconds
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
 
 object TransformWithStateSuiteUtils {
   val NUM_SHUFFLE_PARTITIONS = 5
@@ -30,16 +32,19 @@ object TransformWithStateSuiteUtils {
 class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)]
   with Logging {
-  @transient private var _countState: ValueState[Long] = _
+  @transient protected var _countState: ValueState[Long] = _
 
-  override def init(outputMode: OutputMode): Unit = {
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode): Unit = {
     _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
   }
 
   override def handleInputRows(
       key: String,
       inputRows: Iterator[String],
-      timerValues: TimerValues): Iterator[(String, String)] = {
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
     val count = _countState.getOption().getOrElse(0L) + 1
     if (count == 3) {
       _countState.clear()
@@ -49,8 +54,165 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S
       Iterator((key, count.toString))
     }
   }
+}
+
+// Class to verify stateful processor usage with adding processing time timers
+class RunningCountStatefulProcessorWithProcTimeTimer extends RunningCountStatefulProcessor {
+  private def handleProcessingTimeBasedTimers(
+      key: String,
+      expiryTimestampMs: Long): Iterator[(String, String)] = {
+    _countState.clear()
+    Iterator((key, "-1"))
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
+
+    if (expiredTimerInfo.isValid()) {
+      handleProcessingTimeBasedTimers(key, expiredTimerInfo.getExpiryTimeInMs())
+    } else {
+      val currCount = _countState.getOption().getOrElse(0L)
+      if (currCount == 0 && (key == "a" || key == "c")) {
+        getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs()
+          + 5000)
+      }
+
+      val count = currCount + 1
+      if (count == 3) {
+        _countState.clear()
+        Iterator.empty
+      } else {
+        _countState.update(count)
+        Iterator((key, count.toString))
+      }
+    }
+  }
+}
 
-  override def close(): Unit = {}
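`_countState` also moves from private to protected above, which is what lets the timer-flavored subclasses in this file reuse the base class's count state instead of re-declaring it. The subclass pattern, in brief (class name hypothetical):

```
// Subclass reuses the parent's protected state handle and overrides only the
// row/timer handling (mirrors the classes added in this patch).
class TimerEnabledCounter extends RunningCountStatefulProcessor {
  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
    if (expiredTimerInfo.isValid()) {
      _countState.clear()   // a timer fired: reset and emit a sentinel
      Iterator((key, "-1"))
    } else {
      val count = _countState.getOption().getOrElse(0L) + 1
      _countState.update(count)
      Iterator((key, count.toString))
    }
  }
}
```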
+// Class to verify stateful processor usage with updating processing time timers
+class RunningCountStatefulProcessorWithProcTimeTimerUpdates
+  extends RunningCountStatefulProcessor {
+  @transient private var _timerState: ValueState[Long] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode): Unit = {
+    super.init(outputMode, timeoutMode)
+    _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong)
+  }
+
+  private def handleProcessingTimeBasedTimers(
+      key: String,
+      expiryTimestampMs: Long): Iterator[(String, String)] = {
+    _timerState.clear()
+    Iterator((key, "-1"))
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
+    if (expiredTimerInfo.isValid()) {
+      handleProcessingTimeBasedTimers(key, expiredTimerInfo.getExpiryTimeInMs())
+    } else {
+      val currCount = _countState.getOption().getOrElse(0L)
+      val count = currCount + inputRows.size
+      _countState.update(count)
+      if (key == "a") {
+        var nextTimerTs: Long = 0L
+        if (currCount == 0) {
+          nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 5000
+          getHandle.registerTimer(nextTimerTs)
+          _timerState.update(nextTimerTs)
+        } else if (currCount == 1) {
+          getHandle.deleteTimer(_timerState.get())
+          nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 7500
+          getHandle.registerTimer(nextTimerTs)
+          _timerState.update(nextTimerTs)
+        }
+      }
+      Iterator((key, count.toString))
+    }
+  }
+}
+
+class RunningCountStatefulProcessorWithMultipleTimers
+  extends RunningCountStatefulProcessor {
+  private def handleProcessingTimeBasedTimers(
+      key: String,
+      expiryTimestampMs: Long): Iterator[(String, String)] = {
+    val currCount = _countState.getOption().getOrElse(0L)
+    if (getHandle.listTimers().size == 1) {
+      _countState.clear()
+    }
+    Iterator((key, currCount.toString))
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
+    if (expiredTimerInfo.isValid()) {
+      handleProcessingTimeBasedTimers(key, expiredTimerInfo.getExpiryTimeInMs())
+    } else {
+      val currCount = _countState.getOption().getOrElse(0L)
+      val count = currCount + inputRows.size
+      _countState.update(count)
+      if (getHandle.listTimers().isEmpty) {
+        getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)
+        getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
+        getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 15000)
+        assert(getHandle.listTimers().size == 3)
+      }
+      Iterator.empty
+    }
+  }
+}
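`RunningCountStatefulProcessorWithProcTimeTimerUpdates` above demonstrates the update idiom: there is no replace primitive, so updating a timer means remembering the previously registered expiry in value state, deleting it, and registering the new one. Condensed (assuming `_timerState: ValueState[Long]` holds the last registered expiry; the 7.5-second delay is illustrative):

```
// Replace the previously registered timer with a later expiry.
getHandle.deleteTimer(_timerState.get())                            // drop the old timer
val nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 7500
getHandle.registerTimer(nextTimerTs)                                // register the new one
_timerState.update(nextTimerTs)                                     // remember it for next time
```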
+
+class MaxEventTimeStatefulProcessor
+  extends StatefulProcessor[String, (String, Long), (String, Int)]
+  with Logging {
+  @transient var _maxEventTimeState: ValueState[Long] = _
+  @transient var _timerState: ValueState[Long] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode): Unit = {
+    _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState",
+      Encoders.scalaLong)
+    _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, Long)],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Int)] = {
+    val timeoutDelaySec = 5
+    if (expiredTimerInfo.isValid()) {
+      _maxEventTimeState.clear()
+      Iterator((key, -1))
+    } else {
+      val valuesSeq = inputRows.toSeq
+      val maxEventTimeSec = math.max(valuesSeq.map(_._2).max,
+        _maxEventTimeState.getOption().getOrElse(0L))
+      val timeoutTimestampMs = (maxEventTimeSec + timeoutDelaySec) * 1000
+      _maxEventTimeState.update(maxEventTimeSec)
+
+      val registeredTimerMs: Long = _timerState.getOption().getOrElse(0L)
+      if (registeredTimerMs < timeoutTimestampMs) {
+        getHandle.deleteTimer(registeredTimerMs)
+        getHandle.registerTimer(timeoutTimestampMs)
+        _timerState.update(timeoutTimestampMs)
+      }
+      Iterator((key, maxEventTimeSec.toInt))
+    }
+  }
 }
 
 class RunningCountMostRecentStatefulProcessor
@@ -59,14 +221,17 @@ class RunningCountMostRecentStatefulProcessor
   @transient private var _countState: ValueState[Long] = _
   @transient private var _mostRecent: ValueState[String] = _
 
-  override def init(outputMode: OutputMode): Unit = {
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode): Unit = {
     _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
     _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING)
   }
 
   override def handleInputRows(
       key: String,
       inputRows: Iterator[(String, String)],
-      timerValues: TimerValues): Iterator[(String, String, String)] = {
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = {
     val count = _countState.getOption().getOrElse(0L) + 1
     val mostRecent = _mostRecent.getOption().getOrElse("")
@@ -78,8 +243,6 @@ class RunningCountMostRecentStatefulProcessor
     }
     output.iterator
   }
-
-  override def close(): Unit = {}
 }
 
 class MostRecentStatefulProcessorWithDeletion
@@ -87,7 +250,9 @@ class MostRecentStatefulProcessorWithDeletion
   with Logging {
   @transient private var _mostRecent: ValueState[String] = _
 
-  override def init(outputMode: OutputMode): Unit = {
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode): Unit = {
     getHandle.deleteIfExists("countState")
     _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING)
   }
@@ -95,7 +260,8 @@ class MostRecentStatefulProcessorWithDeletion
   override def handleInputRows(
       key: String,
       inputRows: Iterator[(String, String)],
-      timerValues: TimerValues): Iterator[(String, String)] = {
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
     val mostRecent = _mostRecent.getOption().getOrElse("")
 
     var output = List[(String, String)]()
@@ -105,17 +271,17 @@ class MostRecentStatefulProcessorWithDeletion
     }
     output.iterator
   }
-
-  override def close(): Unit = {}
 }
 
+// Class to verify incorrect usage of stateful processor
 class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcessor {
   @transient private var _tempState: ValueState[Long] = _
 
   override def handleInputRows(
       key: String,
       inputRows: Iterator[String],
-      timerValues: TimerValues): Iterator[(String, String)] = {
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
     // Trying to create value state here should fail
     _tempState = getHandle.getValueState[Long]("tempState", Encoders.scalaLong)
     Iterator.empty
@@ -144,8 +310,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest
 
     testStream(result, OutputMode.Update())(
       AddData(inputData, "a"),
-      ExpectFailure[SparkException] { t =>
-        assert(t.getCause.getMessage.contains("Cannot create state variable"))
+      ExpectFailure[StatefulProcessorCannotPerformOperationWithInvalidHandleState] { t =>
+        assert(t.getMessage.contains("invalid handle state"))
       }
     )
   }
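All of the timer tests that follow opt into the RocksDB provider explicitly, since the timer indexes are built on column families and the HDFS-backed provider rejects those (see the ValueStateSuite change earlier in this patch). The recurring setup, as a sketch:

```
// Timer tests require a column-family-capable state store provider.
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
    classOf[RocksDBStateStoreProvider].getName) {
  // testStream(...) body goes here
}
```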
@@ -180,6 +346,154 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
+  test("transformWithState - streaming with rocksdb and processing time timer " +
+    "should succeed") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val clock = new StreamManualClock
+
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(),
+          TimeoutMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputData, "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "1")),
+
+        AddData(inputData, "b"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("b", "1")),
+
+        AddData(inputData, "b"),
+        AdvanceManualClock(10 * 1000),
+        CheckNewAnswer(("a", "-1"), ("b", "2")),
+
+        StopStream,
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputData, "b"),
+        AddData(inputData, "c"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("c", "1")), // should remove 'b' as count reaches 3
+
+        AddData(inputData, "d"),
+        AdvanceManualClock(10 * 1000),
+        CheckNewAnswer(("c", "-1"), ("d", "1")),
+        StopStream
+      )
+    }
+  }
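The test above drives expiry deterministically with `StreamManualClock`: each `AdvanceManualClock` moves processing time and triggers a microbatch, so a timer registered at `t + 5000` ms fires in the first batch whose clock has passed it, which is why the 10-second jumps flush the pending timers. The skeleton of the pattern (contents illustrative, mirroring the test above):

```
val clock = new StreamManualClock
testStream(result, OutputMode.Update())(
  StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
  AddData(inputData, "a"),
  AdvanceManualClock(1 * 1000),  // batch runs at t = 1s; timer registered for t = 6s
  CheckNewAnswer(("a", "1")),
  AddData(inputData, "b"),
  AdvanceManualClock(10 * 1000), // clock jumps past t = 6s, so the timer for "a" fires
  CheckNewAnswer(("a", "-1"), ("b", "1"))
)
```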
+  test("transformWithState - streaming with rocksdb and processing time timer " +
+    "and updating timers should succeed") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val clock = new StreamManualClock
+
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(
+          new RunningCountStatefulProcessorWithProcTimeTimerUpdates(),
+          TimeoutMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputData, "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "1")), // at batch 0, ts = 1, timer = "a" -> [6] (= 1 + 5)
+
+        AddData(inputData, "a"),
+        AdvanceManualClock(2 * 1000),
+        CheckNewAnswer(("a", "2")), // at batch 1, ts = 3, timer = "a" -> [9.5] (2 + 7.5)
+        StopStream,
+
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputData, "d"),
+        AdvanceManualClock(10 * 1000),
+        CheckNewAnswer(("a", "-1"), ("d", "1")), // at batch 2, ts = 13, timer for "a" is expired.
+        // If the timer of "a" had not been replaced (pure addition), the timer would have fired
+        // two times here and produced ("a", "-1") two times.
+        StopStream
+      )
+    }
+  }
+
+  test("transformWithState - streaming with rocksdb and processing time timer " +
+    "and multiple timers should succeed") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val clock = new StreamManualClock
+
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(
+          new RunningCountStatefulProcessorWithMultipleTimers(),
+          TimeoutMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputData, "a"),
+        AdvanceManualClock(1 * 1000), // at batch 0, add 3 timers for given key = "a"
+
+        AddData(inputData, "a"),
+        AdvanceManualClock(6 * 1000),
+        CheckNewAnswer(("a", "2")), // at ts = 7, first timer expires and produces ("a", "2")
+
+        AddData(inputData, "a"),
+        AdvanceManualClock(5 * 1000),
+        CheckNewAnswer(("a", "3")), // at ts = 12, second timer expires and produces ("a", "3")
+        StopStream,
+
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputData, "a"),
+        AdvanceManualClock(5 * 1000),
+        CheckNewAnswer(("a", "4")), // at ts = 17, third timer expires and produces ("a", "4")
+        StopStream
+      )
+    }
+  }
+
+  test("transformWithState - streaming with rocksdb and event time based timer") {
+    val inputData = MemoryStream[(String, Int)]
+    val result =
+      inputData.toDS()
+        .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
+        .withWatermark("eventTime", "10 seconds")
+        .as[(String, Long)]
+        .groupByKey(_._1)
+        .transformWithState(
+          new MaxEventTimeStatefulProcessor(),
+          TimeoutMode.EventTime(),
+          OutputMode.Update())
+
+    testStream(result, OutputMode.Update())(
+      StartStream(),
+
+      AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
+      // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5.
+      CheckNewAnswer(("a", 15)), // Output = max event time of "a"
+
+      AddData(inputData, ("a", 4)), // Add data older than watermark for "a"
+      CheckNewAnswer(), // No output as data should get filtered by watermark
+
+      AddData(inputData, ("a", 10)), // Add data newer than watermark for "a"
+      CheckNewAnswer(("a", 15)), // Max event time is still the same.
+      // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15.
+      // Watermark is still 5 as max event time for all data is still 15.
+
+      AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a"
+      // Watermark = 31 - 10 = 21, so "a" should time out as its timeout timestamp is 20.
+      CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should time out and emit -1
+    )
+  }
+
   test("Use statefulProcessor without transformWithState - handle should be absent") {
     val processor = new RunningCountStatefulProcessor()
     val ex = intercept[Exception] {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org