This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d55bb617a135 [SPARK-47558][SS] State TTL support for ValueState
d55bb617a135 is described below

commit d55bb617a13561f0eb9f301089a4e4fb06e06228
Author: Bhuwan Sahni <bhuwan.sa...@databricks.com>
AuthorDate: Mon Apr 8 12:22:04 2024 +0900

[SPARK-47558][SS] State TTL support for ValueState

**Note**: This change has been co-authored by ericm-db and sahnib

**Authors: ericm-db sahnib**

### What changes were proposed in this pull request?

This PR adds support for expiring ValueState entries based on TTL. With this functionality, Spark users can specify a TTL mode for the `transformWithState` operator and provide a `ttlDuration`/`expirationTimeInMs` for each value in ValueState. TTL support for List/Map State will be added in future PRs.

Once the `ttlDuration` has expired, the value is no longer returned by `get()` and is cleaned up at the end of the micro-batch.

### Why are the changes needed?

These changes are needed to support TTL for ValueState. The PR supports specifying TTL in either processing time or event time. Processing-time TTL is computed by adding `ttlDuration` to `batchTimestamp`, while event-time TTL is specified as an absolute expiration time (`expirationTimeInMs`).

### Does this PR introduce _any_ user-facing change?

Yes. It modifies the ValueState interface to allow specifying `ttlDuration`, and adds `ttlMode` to the `transformWithState` API.

### How was this patch tested?

Added unit test cases for both event time and processing time in `ValueStateWithTTLSuite`.

```
WARNING: Using incubator modules: jdk.incubator.foreign, jdk.incubator.vector
[info] TransformWithStateTTLSuite:
11:56:54.590 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
11:56:56.054 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - validate state is evicted at ttl expiry - processing time ttl (6 seconds, 244 milliseconds)
11:57:01.188 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - validate ttl update updates the expiration timestamp - processing time ttl (4 seconds, 465 milliseconds)
11:57:05.641 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - validate ttl removal keeps value in state - processing time ttl (4 seconds, 407 milliseconds)
11:57:10.041 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - validate multiple value states - with and without ttl - processing time ttl (3 seconds, 131 milliseconds)
11:57:13.175 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - validate state is evicted at ttl expiry - event time ttl (4 seconds, 186 milliseconds)
11:57:17.355 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - validate ttl update updates the expiration timestamp - event time ttl (4 seconds, 28 milliseconds)
11:57:21.391 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - validate ttl removal keeps value in state - event time ttl (4 seconds, 428 milliseconds)
11:57:25.838 WARN org.apache.spark.sql.streaming.TransformWithStateTTLSuite:
[info] Run completed in 32 seconds, 433 milliseconds.
[info] Total number of tests run: 7
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 7, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #45674 from sahnib/state-ttl.

Authored-by: Bhuwan Sahni <bhuwan.sa...@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../src/main/resources/error/error-classes.json | 17 +
 .../apache/spark/sql/KeyValueGroupedDataset.scala | 14 +-
 dev/checkstyle-suppressions.xml | 2 +
 ...r-conditions-unsupported-feature-error-class.md | 4 +
 docs/sql-error-conditions.md | 12 +
 .../org/apache/spark/sql/streaming/TTLMode.java} | 40 +-
 .../plans/logical/TTLMode.scala} | 36 +-
 .../spark/sql/streaming/StatefulProcessor.scala | 3 +-
 .../sql/streaming/StatefulProcessorHandle.scala | 25 +-
 .../{ValueState.scala => TTLConfig.scala} | 36 +-
 .../apache/spark/sql/streaming/ValueState.scala | 6 +-
 .../spark/sql/catalyst/plans/logical/object.scala | 7 +-
 .../apache/spark/sql/KeyValueGroupedDataset.scala | 24 +-
 .../spark/sql/execution/SparkStrategies.scala | 7 +-
 .../sql/execution/streaming/ListStateImpl.scala | 2 +-
 .../sql/execution/streaming/MapStateImpl.scala | 2 +-
 .../streaming/StateTypesEncoderUtils.scala | 84 +++-
 .../streaming/StatefulProcessorHandleImpl.scala | 63 ++-
 .../spark/sql/execution/streaming/TTLState.scala | 153 +++++++
 .../sql/execution/streaming/TimerStateImpl.scala | 8 +-
 .../streaming/TransformWithStateExec.scala | 105 +++--
 .../sql/execution/streaming/ValueStateImpl.scala | 33 +-
 .../streaming/ValueStateImplWithTTL.scala | 184 ++++++++
 .../streaming/state/StateStoreErrors.scala | 29 ++
 .../org/apache/spark/sql/JavaDatasetSuite.java | 2 +
.../apache/spark/sql/TestStatefulProcessor.java | 5 +- .../sql/TestStatefulProcessorWithInitialState.java | 5 +- .../execution/streaming/state/ListStateSuite.scala | 16 +- .../execution/streaming/state/MapStateSuite.scala | 11 +- .../state/StatefulProcessorHandleSuite.scala | 46 +- .../streaming/state/ValueStateSuite.scala | 117 ++++- .../streaming/TransformWithListStateSuite.scala | 14 +- .../sql/streaming/TransformWithMapStateSuite.scala | 8 +- .../TransformWithStateInitialStateSuite.scala | 19 +- .../sql/streaming/TransformWithStateSuite.scala | 34 +- .../TransformWithValueStateTTLSuite.scala | 471 +++++++++++++++++++++ 36 files changed, 1407 insertions(+), 237 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index aeb35b864c66..f28adaf40230 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3579,6 +3579,12 @@ ], "sqlState" : "0A000" }, + "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE" : { + "message" : [ + "Cannot use TTL for state=<stateName> in NoTTL() mode." + ], + "sqlState" : "42802" + }, "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : { "message" : [ "Failed to perform stateful processor operation=<operationType> with invalid handle state=<handleState>." @@ -3597,6 +3603,12 @@ ], "sqlState" : "42802" }, + "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE" : { + "message" : [ + "TTL duration must be greater than zero for State store operation=<operationType> on state=<stateName>." + ], + "sqlState" : "42802" + }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=<colFamilyName>." @@ -4391,6 +4403,11 @@ "Removing column families with <stateStoreProvider> is not supported." 
] }, + "STATE_STORE_TTL" : { + "message" : [ + "State TTL with <stateStoreProvider> is not supported. Please use RocksDBStateStoreProvider." + ] + }, "TABLE_OPERATION" : { "message" : [ "Table <tableName> does not support <operation>. Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by \"spark.sql.catalog\"." diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 1b712348d865..39e0c429046d 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.ScalarUserDefinedFunction import org.apache.spark.sql.functions.col -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -829,12 +829,15 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * Instance of statefulProcessor whose functions will be invoked by the operator. * @param timeoutMode * The timeout mode of the stateful processor. + * @param ttlMode + * The ttlMode to evict user state on ttl expiration. * @param outputMode * The output mode of the stateful processor. 
*/ def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode): Dataset[U] = { throw new UnsupportedOperationException } @@ -853,6 +856,8 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * Instance of statefulProcessor whose functions will be invoked by the operator. * @param timeoutMode * The timeout mode of the stateful processor. + * @param ttlMode + * The ttlMode to evict user state on ttl expiration. * @param outputMode * The output mode of the stateful processor. * @param outputEncoder @@ -861,6 +866,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode, outputEncoder: Encoder[U]): Dataset[U] = { throw new UnsupportedOperationException @@ -879,6 +885,8 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * Instance of statefulProcessor whose functions will be invoked by the operator. * @param timeoutMode * The timeout mode of the stateful processor. + * @param ttlMode + * The ttlMode to evict user state on ttl expiration. * @param outputMode * The output mode of the stateful processor. * @param initialState @@ -890,6 +898,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { throw new UnsupportedOperationException @@ -908,6 +917,8 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * Instance of statefulProcessor whose functions will be invoked by the operator. * @param timeoutMode * The timeout mode of the stateful processor. 
+ * @param ttlMode + * The ttlMode to evict user state on ttl expiration * @param outputMode * The output mode of the stateful processor. * @param initialState @@ -923,6 +934,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S], outputEncoder: Encoder[U], diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 7b20dfb6bce5..94dfe20af56e 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -60,6 +60,8 @@ files="sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java"/> <suppress checks="MethodName" files="sql/api/src/main/java/org/apache/spark/sql/streaming/Trigger.java"/> + <suppress checks="MethodName" + files="sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java"/> <suppress checks="LineLength" files="src/main/java/org/apache/spark/sql/api/java/*"/> <suppress checks="IllegalImport" diff --git a/docs/sql-error-conditions-unsupported-feature-error-class.md b/docs/sql-error-conditions-unsupported-feature-error-class.md index e580ecc63b18..f67d7caff63d 100644 --- a/docs/sql-error-conditions-unsupported-feature-error-class.md +++ b/docs/sql-error-conditions-unsupported-feature-error-class.md @@ -202,6 +202,10 @@ Creating multiple column families with `<stateStoreProvider>` is not supported. Removing column families with `<stateStoreProvider>` is not supported. +## STATE_STORE_TTL + +State TTL with `<stateStoreProvider>` is not supported. Please use RocksDBStateStoreProvider. + ## TABLE_OPERATION Table `<tableName>` does not support `<operation>`. 
Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by "spark.sql.catalog". diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index ee3a3bd07a77..d8261b8c2765 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2183,6 +2183,12 @@ The SQL config `<sqlConf>` cannot be found. Please verify that the config exists Star (*) is not allowed in a select list when GROUP BY an ordinal position is used. +### STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot use TTL for state=`<stateName>` in NoTTL() mode. + ### STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -2201,6 +2207,12 @@ Failed to perform stateful processor operation=`<operationType>` with invalid ti Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=`<groupingKey>`. +### STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +TTL duration must be greater than zero for State store operation=`<operationType>` on state=`<stateName>`. 
+ ### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java similarity index 53% copy from sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala copy to sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java index 9c707c8308ab..30594770b3e1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java @@ -15,36 +15,28 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.streaming; -import java.io.Serializable +import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.catalyst.plans.logical.*; -import org.apache.spark.annotation.{Evolving, Experimental} - -@Experimental -@Evolving /** - * Interface used for arbitrary stateful operations with the v2 API to capture - * single value state. + * Represents the type of ttl modes possible for the Dataset operations + * {@code transformWithState}. */ -private[sql] trait ValueState[S] extends Serializable { - - /** Whether state exists or not. */ - def exists(): Boolean +@Experimental +@Evolving +public class TTLMode { /** - * Get the state value if it exists - * @throws java.util.NoSuchElementException if the state does not exist + * Specifies that there is no TTL for the user state. User state would not + * be cleaned up by Spark automatically. */ - @throws[NoSuchElementException] - def get(): S - - /** Get the state if it exists as an option and None otherwise */ - def getOption(): Option[S] - - /** Update the value of the state. 
*/ - def update(newState: S): Unit + public static final TTLMode NoTTL() { return NoTTL$.MODULE$; } - /** Remove this state. */ - def clear(): Unit + /** + * Specifies that all ttl durations for user state are in processing time. + */ + public static final TTLMode ProcessingTimeTTL() { return ProcessingTimeTTL$.MODULE$; } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala similarity index 50% copy from sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala copy to sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala index 9c707c8308ab..be4794a5f40b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala @@ -14,37 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.sql.catalyst.plans.logical -package org.apache.spark.sql.streaming +import org.apache.spark.sql.streaming.TTLMode -import java.io.Serializable +/** TTL types used in tranformWithState operator */ +case object NoTTL extends TTLMode -import org.apache.spark.annotation.{Evolving, Experimental} - -@Experimental -@Evolving -/** - * Interface used for arbitrary stateful operations with the v2 API to capture - * single value state. - */ -private[sql] trait ValueState[S] extends Serializable { - - /** Whether state exists or not. */ - def exists(): Boolean - - /** - * Get the state value if it exists - * @throws java.util.NoSuchElementException if the state does not exist - */ - @throws[NoSuchElementException] - def get(): S - - /** Get the state if it exists as an option and None otherwise */ - def getOption(): Option[S] - - /** Update the value of the state. */ - def update(newState: S): Unit - - /** Remove this state. 
*/ - def clear(): Unit -} +case object ProcessingTimeTTL extends TTLMode diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index 42d12dd91e94..70f9cdfa399a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -44,7 +44,8 @@ private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable { */ def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit /** * Function that will allow users to interact with input data rows along with the grouping key diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 560188a0ff62..e65667206ded 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -30,16 +30,37 @@ import org.apache.spark.sql.Encoder private[sql] trait StatefulProcessorHandle extends Serializable { /** - * Function to create new or return existing single value state variable of given type + * Function to create new or return existing single value state variable of given type. * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. 
- * @param stateName - name of the state variable + * + * @param stateName - name of the state variable * @param valEncoder - SQL encoder for state variable * @tparam T - type of state variable * @return - instance of ValueState of type T that can be used to store state persistently */ def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] + /** + * Function to create new or return existing single value state variable of given type + * with ttl. State values will not be returned past ttlDuration, and will be eventually removed + * from the state store. Any state update resets the ttl to current processing time plus + * ttlDuration. + * + * The user must ensure to call this function only within the `init()` method of the + * StatefulProcessor. + * + * @param stateName - name of the state variable + * @param valEncoder - SQL encoder for state variable + * @param ttlConfig - the ttl configuration (time to live duration etc.) + * @tparam T - type of state variable + * @return - instance of ValueState of type T that can be used to store state persistently + */ + def getValueState[T]( + stateName: String, + valEncoder: Encoder[T], + ttlConfig: TTLConfig): ValueState[T] + /** * Creates new or returns existing list state associated with stateName. * The ListState persists values of type T. 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala similarity index 53% copy from sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala copy to sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala index 9c707c8308ab..576e09d5d7fe 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala @@ -17,34 +17,14 @@ package org.apache.spark.sql.streaming -import java.io.Serializable +import java.time.Duration -import org.apache.spark.annotation.{Evolving, Experimental} - -@Experimental -@Evolving /** - * Interface used for arbitrary stateful operations with the v2 API to capture - * single value state. + * TTL Configuration for state variable. State values will not be returned past ttlDuration, + * and will be eventually removed from the state store. Any state update resets the ttl to + * current processing time plus ttlDuration. + * + * @param ttlDuration time to live duration for state + * stored in the state variable. */ -private[sql] trait ValueState[S] extends Serializable { - - /** Whether state exists or not. */ - def exists(): Boolean - - /** - * Get the state value if it exists - * @throws java.util.NoSuchElementException if the state does not exist - */ - @throws[NoSuchElementException] - def get(): S - - /** Get the state if it exists as an option and None otherwise */ - def getOption(): Option[S] - - /** Update the value of the state. */ - def update(newState: S): Unit - - /** Remove this state. 
*/ - def clear(): Unit -} +case class TTLConfig(ttlDuration: Duration) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index 9c707c8308ab..8a2661e1a55b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -42,7 +42,11 @@ private[sql] trait ValueState[S] extends Serializable { /** Get the state if it exists as an option and None otherwise */ def getOption(): Option[S] - /** Update the value of the state. */ + /** + * Update the value of the state. + * + * @param newState the new value + */ def update(newState: S): Unit /** Remove this state. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index b2c443a8cce0..ff7c8fb3df4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode, TTLMode} import org.apache.spark.sql.types._ object CatalystSerde { @@ -574,6 +574,7 @@ object TransformWithState { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[K, V, U], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, child: LogicalPlan): LogicalPlan = { @@ -584,6 +585,7 @@ object 
TransformWithState { groupingAttributes, dataAttributes, statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], + ttlMode, timeoutMode, outputMode, keyEncoder.asInstanceOf[ExpressionEncoder[Any]], @@ -605,6 +607,7 @@ object TransformWithState { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[K, V, U], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, child: LogicalPlan, @@ -618,6 +621,7 @@ object TransformWithState { groupingAttributes, dataAttributes, statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], + ttlMode, timeoutMode, outputMode, keyEncoder.asInstanceOf[ExpressionEncoder[Any]], @@ -639,6 +643,7 @@ case class TransformWithState( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 55ac3daa6209..f3713edd0ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.TypedAggUtils -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode, TTLMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. 
Users should not @@ -652,16 +652,18 @@ class KeyValueGroupedDataset[K, V] private[sql]( * invocations. * * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the - * operator. - * @param timeoutMode The timeout mode of the stateful processor. - * @param outputMode The output mode of the stateful processor. + * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked + * by the operator. + * @param timeoutMode The timeout mode of the stateful processor. + * @param ttlMode The ttlMode to evict user state on ttl expiration + * @param outputMode The output mode of the stateful processor. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. */ private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode): Dataset[U] = { Dataset[U]( sparkSession, @@ -669,6 +671,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, child = logicalPlan @@ -689,6 +692,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the * operator. * @param timeoutMode The timeout mode of the stateful processor. + * @param ttlMode The ttlMode to evict user state on ttl expiration * @param outputMode The output mode of the stateful processor. * @param outputEncoder Encoder for the output type. 
* @@ -697,9 +701,10 @@ class KeyValueGroupedDataset[K, V] private[sql]( private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode, outputEncoder: Encoder[U]): Dataset[U] = { - transformWithState(statefulProcessor, timeoutMode, outputMode)(outputEncoder) + transformWithState(statefulProcessor, timeoutMode, ttlMode, outputMode)(outputEncoder) } /** @@ -712,6 +717,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param statefulProcessor Instance of statefulProcessor whose functions will * be invoked by the operator. * @param timeoutMode The timeout mode of the stateful processor. + * @param ttlMode The ttlMode to evict user state on ttl expiration * @param outputMode The output mode of the stateful processor. * @param initialState User provided initial state that will be used to initiate state for * the query in the first batch. @@ -721,6 +727,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { Dataset[U]( @@ -729,6 +736,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, child = logicalPlan, @@ -749,6 +757,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param statefulProcessor Instance of statefulProcessor whose functions will * be invoked by the operator. * @param timeoutMode The timeout mode of the stateful processor. + * @param ttlMode The ttlMode to evict user state on ttl expiration * @param outputMode The output mode of the stateful processor. * @param initialState User provided initial state that will be used to initiate state for * the query in the first batch. 
@@ -760,11 +769,12 @@ class KeyValueGroupedDataset[K, V] private[sql]( private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeoutMode: TimeoutMode, + ttlMode: TTLMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S], outputEncoder: Encoder[U], initialStateEncoder: Encoder[S]): Dataset[U] = { - transformWithState(statefulProcessor, timeoutMode, + transformWithState(statefulProcessor, timeoutMode, ttlMode, outputMode, initialState)(outputEncoder, initialStateEncoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cc212d99f299..2c534eb36f9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -751,7 +751,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case TransformWithState( keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, + dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputAttr, child, hasInitialState, initialStateGroupingAttrs, initialStateDataAttrs, initialStateDeserializer, initialState) => @@ -761,6 +761,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -925,12 +926,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hasInitialState, planLater(initialState), planLater(child) ) :: Nil case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder, + dataAttributes, statefulProcessor, 
ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, child, hasInitialState, initialStateGroupingAttrs, initialStateDataAttrs, initialStateDeserializer, initialState) => TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer, - groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, + groupingAttributes, dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder, outputObjAttr, planLater(child), hasInitialState, initialStateGroupingAttrs, initialStateDataAttrs, initialStateDeserializer, planLater(initialState)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 662bef5716ea..56c9d2664d9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index d2ccd0a77807..c58f32ed756d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -45,7 +45,7 @@ class MapStateImpl[K, V]( /** Whether state exists or not. */ override def exists(): Boolean = { - !store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).isEmpty + store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty } /** Get the state value if it exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index 1d41db896cdf..b2dba7668d62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -23,11 +23,15 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.state.StateStoreErrors -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, LongType, StructType} -object StateKeyValueRowSchema { +object TransformWithStateKeyValueRowSchema { val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType) - val VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType) + val VALUE_ROW_SCHEMA: StructType = new StructType() + .add("value", BinaryType) + val VALUE_ROW_SCHEMA_WITH_TTL: StructType = new StructType() + .add("value", BinaryType) + .add("ttlExpirationMs", LongType) } /** @@ -49,12 +53,17 @@ object StateKeyValueRowSchema { class StateTypesEncoder[GK, V]( keySerializer: Serializer[GK], valEncoder: Encoder[V], - stateName: String) { - import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema._ + stateName: String, + hasTtl: Boolean) { 
+ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._ /** Variables reused for conversions between byte array and UnsafeRow */ private val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA) - private val valueProjection = UnsafeProjection.create(VALUE_ROW_SCHEMA) + private val valueProjection = if (hasTtl) { + UnsafeProjection.create(VALUE_ROW_SCHEMA_WITH_TTL) + } else { + UnsafeProjection.create(VALUE_ROW_SCHEMA) + } /** Variables reused for value conversions between spark sql and object */ private val valExpressionEnc = encoderFor(valEncoder) @@ -65,22 +74,47 @@ class StateTypesEncoder[GK, V]( // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. def encodeGroupingKey(): UnsafeRow = { + val keyRow = keyProjection(InternalRow(serializeGroupingKey())) + keyRow + } + + /** + * Encodes the provided grouping key into Spark UnsafeRow. + * + * @param groupingKeyBytes serialized grouping key byte array + * @return encoded UnsafeRow + */ + def encodeSerializedGroupingKey(groupingKeyBytes: Array[Byte]): UnsafeRow = { + val keyRow = keyProjection(InternalRow(groupingKeyBytes)) + keyRow + } + + def serializeGroupingKey(): Array[Byte] = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(stateName) } - val groupingKey = keyOption.get.asInstanceOf[GK] - val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() - val keyRow = keyProjection(InternalRow(keyByteArr)) - keyRow + keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes() } + /** + * Encode the specified value in Spark UnsafeRow with no ttl. 
+ */ def encodeValue(value: V): UnsafeRow = { val objRow: InternalRow = objToRowSerializer.apply(value) val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() - val valRow = valueProjection(InternalRow(bytes)) - valRow + valueProjection(InternalRow(bytes)) + } + + /** + * Encode the specified value in Spark UnsafeRow + * with provided ttl expiration. + */ + def encodeValue(value: V, expirationMs: Long): UnsafeRow = { + val objRow: InternalRow = objToRowSerializer.apply(value) + val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() + valueProjection(InternalRow(bytes, expirationMs)) } def decodeValue(row: UnsafeRow): V = { @@ -89,14 +123,31 @@ class StateTypesEncoder[GK, V]( val value = rowToObjDeserializer.apply(reusedValRow) value } + + /** + * Decode the ttl information out of Value row. If the ttl has + * not been set (-1L specifies no user defined value), the API will + * return None. + */ + def decodeTtlExpirationMs(row: UnsafeRow): Option[Long] = { + // ensure ttl has been set + assert(hasTtl) + val expirationMs = row.getLong(1) + if (expirationMs == -1) { + None + } else { + Some(expirationMs) + } + } } object StateTypesEncoder { def apply[GK, V]( keySerializer: Serializer[GK], valEncoder: Encoder[V], - stateName: String): StateTypesEncoder[GK, V] = { - new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName) + stateName: String, + hasTtl: Boolean = false): StateTypesEncoder[GK, V] = { + new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName, hasTtl) } } @@ -105,8 +156,9 @@ class CompositeKeyStateEncoder[GK, K, V]( userKeyEnc: Encoder[K], valEncoder: Encoder[V], schemaForCompositeKeyRow: StructType, - stateName: String) - extends StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName) { + stateName: String, + hasTtl: Boolean = false) + extends StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName, hasTtl) { private val compositeKeyProjection = UnsafeProjection.create(schemaForCompositeKeyRow) private val reusedKeyRow = new 
UnsafeRow(userKeyEnc.schema.fields.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 5f3b794fd117..7bef62b7fcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.streaming +import java.util import java.util.UUID import org.apache.spark.TaskContext @@ -24,7 +25,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TimeoutMode, TTLConfig, TTLMode, ValueState} import org.apache.spark.util.Utils /** @@ -77,14 +78,22 @@ class StatefulProcessorHandleImpl( store: StateStore, runId: UUID, keyEncoder: ExpressionEncoder[Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, - isStreaming: Boolean = true) + isStreaming: Boolean = true, + batchTimestampMs: Option[Long] = None) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ + /** + * Stores all the active ttl states, and is used to cleanup expired values + * in [[doTtlCleanup()]] function. 
+ */ + private[sql] val ttlStates: util.List[TTLState] = new util.ArrayList[TTLState]() + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" - private def buildQueryInfo(): QueryInfo = { + private def buildQueryInfo(): QueryInfo = { val taskCtxOpt = Option(TaskContext.get()) val (queryId, batchId) = if (!isStreaming) { (BATCH_QUERY_ID, 0L) @@ -103,22 +112,33 @@ class StatefulProcessorHandleImpl( private var currState: StatefulProcessorHandleState = CREATED - private def verify(condition: => Boolean, msg: String): Unit = { - if (!condition) { - throw new IllegalStateException(msg) - } - } - def setHandleState(newState: StatefulProcessorHandleState): Unit = { currState = newState } def getHandleState: StatefulProcessorHandleState = currState - override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { + override def getValueState[T]( + stateName: String, + valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state") - val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) - resultState + + new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + } + + override def getValueState[T]( + stateName: String, + valEncoder: Encoder[T], + ttlConfig: TTLConfig): ValueState[T] = { + verifyStateVarOperations("get_value_state") + validateTTLConfig(ttlConfig, stateName) + + assert(batchTimestampMs.isDefined) + val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) + ttlStates.add(valueStateWithTTL) + + valueStateWithTTL } override def getQueryInfo(): QueryInfo = currQueryInfo @@ -185,6 +205,16 @@ class StatefulProcessorHandleImpl( timerState.listTimers() } + /** + * Performs the user state cleanup based on assigned TTl values. Any state + * which is expired will be cleaned up from StateStore. 
+ */ + def doTtlCleanup(): Unit = { + ttlStates.forEach { s => + s.clearExpiredState() + } + } + /** * Function to delete and purge state variable if defined previously * @@ -209,4 +239,13 @@ class StatefulProcessorHandleImpl( val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) resultState } + + private def validateTTLConfig(ttlConfig: TTLConfig, stateName: String): Unit = { + val ttlDuration = ttlConfig.ttlDuration + if (ttlMode != TTLMode.ProcessingTimeTTL()) { + throw StateStoreErrors.cannotProvideTTLConfigForNoTTLMode(stateName) + } else if (ttlDuration == null || ttlDuration.isNegative || ttlDuration.isZero) { + throw StateStoreErrors.ttlMustBePositive("update", stateName) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala new file mode 100644 index 000000000000..0ae93549b731 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.time.Duration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, StateStore} +import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType} + +object StateTTLSchema { + val TTL_KEY_ROW_SCHEMA: StructType = new StructType() + .add("expirationMs", LongType) + .add("groupingKey", BinaryType) + val TTL_VALUE_ROW_SCHEMA: StructType = + StructType(Array(StructField("__dummy__", NullType))) +} + +/** + * Encapsulates the ttl row information stored in [[SingleKeyTTLStateImpl]]. + * + * @param groupingKey grouping key for which ttl is set + * @param expirationMs expiration time for the grouping key + */ +case class SingleKeyTTLRow( + groupingKey: Array[Byte], + expirationMs: Long) + +/** + * Represents the underlying state for secondary TTL Index for a user defined + * state variable. + * + * This state allows Spark to query ttl values based on expiration time + * allowing efficient ttl cleanup. + */ +trait TTLState { + + /** + * Perform the user state clean up based on ttl values stored in + * this state. NOTE that its not safe to call this operation concurrently + * when the user can also modify the underlying State. Cleanup should be initiated + * after arbitrary state operations are completed by the user. + */ + def clearExpiredState(): Unit + + /** + * Clears the user state associated with this grouping key + * if it has expired. This function is called by Spark to perform + * cleanup at the end of transformWithState processing. + * + * Spark uses a secondary index to determine if the user state for + * this grouping key has expired. However, its possible that the user + * has updated the TTL and secondary index is out of date. 
Implementations + * must validate that the user State has actually expired before cleanup based + * on their own State data. + * + * @param groupingKey grouping key for which cleanup should be performed. + */ + def clearIfExpired(groupingKey: Array[Byte]): Unit +} + +/** + * Manages the ttl information for user state keyed with a single key (grouping key). + */ +abstract class SingleKeyTTLStateImpl( + stateName: String, + store: StateStore, + ttlExpirationMs: Long) + extends TTLState { + + import org.apache.spark.sql.execution.streaming.StateTTLSchema._ + + private val ttlColumnFamilyName = s"_ttl_$stateName" + private val ttlKeyEncoder = UnsafeProjection.create(TTL_KEY_ROW_SCHEMA) + + // empty row used for values + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) + + store.createColFamilyIfAbsent(ttlColumnFamilyName, TTL_KEY_ROW_SCHEMA, TTL_VALUE_ROW_SCHEMA, + RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, 1), isInternal = true) + + def upsertTTLForStateKey( + expirationMs: Long, + groupingKey: Array[Byte]): Unit = { + val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, groupingKey)) + store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName) + } + + /** + * Clears any state which has ttl older than [[ttlExpirationMs]]. 
+ */ + override def clearExpiredState(): Unit = { + val iterator = store.iterator(ttlColumnFamilyName) + + iterator.takeWhile { kv => + val expirationMs = kv.key.getLong(0) + StateTTL.isExpired(expirationMs, ttlExpirationMs) + }.foreach { kv => + val groupingKey = kv.key.getBinary(1) + clearIfExpired(groupingKey) + store.remove(kv.key, ttlColumnFamilyName) + } + } + + private[sql] def ttlIndexIterator(): Iterator[SingleKeyTTLRow] = { + val ttlIterator = store.iterator(ttlColumnFamilyName) + + new Iterator[SingleKeyTTLRow] { + override def hasNext: Boolean = ttlIterator.hasNext + + override def next(): SingleKeyTTLRow = { + val kv = ttlIterator.next() + SingleKeyTTLRow( + expirationMs = kv.key.getLong(0), + groupingKey = kv.key.getBinary(1) + ) + } + } + } +} + +/** + * Helper methods for user State TTL. + */ +object StateTTL { + def calculateExpirationTimeForDuration( + ttlDuration: Duration, + batchTtlExpirationMs: Long): Long = { + batchTtlExpirationMs + ttlDuration.toMillis + } + + def isExpired( + expirationMs: Long, + batchTtlExpirationMs: Long): Boolean = { + batchTtlExpirationMs >= expirationMs + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index af321eecb4db..8d410b677c84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -78,25 +78,25 @@ class TimerStateImpl( private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex) - val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { + private val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) { TimerStateUtils.PROC_TIMERS_STATE_NAME } else { TimerStateUtils.EVENT_TIMERS_STATE_NAME } - val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF + private val keyToTsCFName = 
timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF store.createColFamilyIfAbsent(keyToTsCFName, schemaForKeyRow, schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1), useMultipleValuesPerKey = false, isInternal = true) - val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF + private val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex, schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1), useMultipleValuesPerKey = false, isInternal = true) private def getGroupingKey(cfName: String): Any = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption - if (!keyOption.isDefined) { + if (keyOption.isEmpty) { throw StateStoreErrors.implicitKeyNotFound(cfName) } keyOption.get diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 66c19fa22304..eaf51614d7cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.streaming._ import 
org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} /** @@ -42,6 +42,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti * @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param statefulProcessor processor methods called on underlying data + * @param ttlMode defines the ttl Mode for user state * @param timeoutMode defines the timeout mode * @param outputMode defines the output mode for the statefulProcessor * @param keyEncoder expression encoder for the key type @@ -58,6 +59,7 @@ case class TransformWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], @@ -78,17 +80,14 @@ case class TransformWithStateExec( override def shortName: String = "transformWithStateExec" override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { - timeoutMode match { + if (ttlMode == TTLMode.ProcessingTimeTTL() || timeoutMode == TimeoutMode.ProcessingTime()) { // TODO: check if we can return true only if actual timers are registered - case ProcessingTime => - true - - case EventTime => - eventTimeWatermarkForEviction.isDefined && - newInputWatermark > eventTimeWatermarkForEviction.get - - case _ => - false + true + } else if (timeoutMode == TimeoutMode.EventTime()) { + eventTimeWatermarkForEviction.isDefined && + newInputWatermark > eventTimeWatermarkForEviction.get + } else { + false } } @@ -102,10 +101,6 @@ case class TransformWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes - protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) - - protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType) - /** * Distribute by grouping attributes - We need the underlying data and the initial state data * to 
have the same grouping so that the data are co-located on the same task. @@ -284,6 +279,8 @@ case class TransformWithStateExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) commitTimeMs += timeTakenMs { if (isStreaming) { + // clean up any expired user state + processorHandle.doTtlCleanup() store.commit() } else { store.abort() @@ -300,19 +297,8 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - timeoutMode match { - case ProcessingTime => - if (batchTimestampMs.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case EventTime => - if (eventTimeWatermarkForEviction.isEmpty) { - StateStoreErrors.missingTimeoutValues(timeoutMode.toString) - } - - case _ => - } + validateTTLMode() + validateTimeoutMode() if (hasInitialState) { val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf) @@ -332,9 +318,9 @@ case class TransformWithStateExec( val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId) val store = StateStore.get( storeProviderId = storeProviderId, - keySchema = schemaForKeyRow, - valueSchema = schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), version = stateInfo.get.storeVersion, useColumnFamilies = true, storeConf = storeConf, @@ -352,9 +338,9 @@ case class TransformWithStateExec( if (isStreaming) { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), session.sqlContext.sessionState, Some(session.sqlContext.streams.stateStoreCoordinator), useColumnFamilies = true @@ -402,9 +388,9 @@ case class TransformWithStateExec( // Create StateStoreProvider for this partition val stateStoreProvider 
= StateStoreProvider.createAndInit( providerId, - schemaForKeyRow, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(schemaForKeyRow), + KEY_ROW_SCHEMA, + VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useColumnFamilies = true, storeConf = storeConf, hadoopConf = hadoopConfBroadcast.value.value, @@ -427,10 +413,11 @@ case class TransformWithStateExec( private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( - store, getStateInfo.queryRunId, keyEncoder, timeoutMode, isStreaming) + store, getStateInfo.queryRunId, keyEncoder, ttlMode, timeoutMode, + isStreaming, batchTimestampMs) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) - statefulProcessor.init(outputMode, timeoutMode) + statefulProcessor.init(outputMode, timeoutMode, ttlMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) processDataWithPartition(singleIterator, store, processorHandle) } @@ -441,10 +428,10 @@ case class TransformWithStateExec( initStateIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - keyEncoder, timeoutMode, isStreaming) + keyEncoder, ttlMode, timeoutMode, isStreaming) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) - statefulProcessor.init(outputMode, timeoutMode) + statefulProcessor.init(outputMode, timeoutMode, ttlMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) // Check if is first batch @@ -462,9 +449,36 @@ case class TransformWithStateExec( processDataWithPartition(childDataIterator, store, processorHandle) } + + private def validateTimeoutMode(): Unit = { + timeoutMode match { + case 
ProcessingTime => + if (batchTimestampMs.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case EventTime => + if (eventTimeWatermarkForEviction.isEmpty) { + StateStoreErrors.missingTimeoutValues(timeoutMode.toString) + } + + case _ => + } + } + + private def validateTTLMode(): Unit = { + ttlMode match { + case ProcessingTimeTTL => + if (batchTimestampMs.isEmpty) { + StateStoreErrors.missingTTLValues(timeoutMode.toString) + } + + case _ => + } + } } -// scalastyle:off +// scalastyle:off argcount object TransformWithStateExec { // Plan logical transformWithState for batch queries @@ -474,6 +488,7 @@ object TransformWithStateExec { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], statefulProcessor: StatefulProcessor[Any, Any, Any], + ttlMode: TTLMode, timeoutMode: TimeoutMode, outputMode: OutputMode, keyEncoder: ExpressionEncoder[Any], @@ -499,6 +514,7 @@ object TransformWithStateExec { groupingAttributes, dataAttributes, statefulProcessor, + ttlMode, timeoutMode, outputMode, keyEncoder, @@ -516,4 +532,5 @@ object TransformWithStateExec { initialState) } } -// scalastyle:on +// scalastyle:on argcount + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 08876ca3032e..d916011245c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} +import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.ValueState @@ -29,7 +28,7 @@ import org.apache.spark.sql.streaming.ValueState * variables used in the streaming transformWithState operator. * @param store - reference to the StateStore instance to be used for storing state * @param stateName - name of logical state partition - * @param keyEnc - Spark SQL encoder for key + * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value * @tparam S - data type of object that will be stored */ @@ -37,18 +36,22 @@ class ValueStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) extends ValueState[S] with Logging { + valEncoder: Encoder[S]) + extends ValueState[S] with Logging { private val keySerializer = keyExprEnc.createSerializer() - private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + initialize() + + private def initialize(): Unit = { + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + } /** Function to check if state exists. 
Returns true if present and false otherwise */ override def exists(): Boolean = { - getImpl() != null + get() != null } /** Function to return Option of value if exists and None otherwise */ @@ -58,7 +61,9 @@ class ValueStateImpl[S]( /** Function to return associated value with key if exists and null otherwise */ override def get(): S = { - val retRow = getImpl() + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + if (retRow != null) { stateTypesEncoder.decodeValue(retRow) } else { @@ -66,14 +71,12 @@ class ValueStateImpl[S]( } } - private def getImpl(): UnsafeRow = { - store.get(stateTypesEncoder.encodeGroupingKey(), stateName) - } - /** Function to update and overwrite state associated with given key */ override def update(newState: S): Unit = { - store.put(stateTypesEncoder.encodeGroupingKey(), - stateTypesEncoder.encodeValue(newState), stateName) + val encodedValue = stateTypesEncoder.encodeValue(newState) + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + encodedValue, stateName) } /** Function to remove state for given key */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala new file mode 100644 index 000000000000..d3c9eb9de204 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.streaming.{TTLConfig, ValueState} + +/** + * Class that provides a concrete implementation for a single value state associated with state + * variables (with ttl expiration support) used in the streaming transformWithState operator. + * + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @param keyExprEnc - Spark SQL encoder for key + * @param valEncoder - Spark SQL encoder for value + * @param ttlConfig - TTL configuration for values stored in this state + * @param batchTimestampMs - current batch processing timestamp. 
+ * @tparam S - data type of object that will be stored + */ +class ValueStateImplWithTTL[S]( + store: StateStore, + stateName: String, + keyExprEnc: ExpressionEncoder[Any], + valEncoder: Encoder[S], + ttlConfig: TTLConfig, + batchTimestampMs: Long) + extends SingleKeyTTLStateImpl(stateName, store, batchTimestampMs) with ValueState[S] { + + private val keySerializer = keyExprEnc.createSerializer() + private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, + stateName, hasTtl = true) + private val ttlExpirationMs = + StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) + + initialize() + + private def initialize(): Unit = { + store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + } + + /** Function to check if state exists. Returns true if present and false otherwise */ + override def exists(): Boolean = { + get() != null + } + + /** Function to return Option of value if exists and None otherwise */ + override def getOption(): Option[S] = { + Option(get()) + } + + /** Function to return associated value with key if exists and null otherwise */ + override def get(): S = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = stateTypesEncoder.decodeValue(retRow) + + if (!isExpired(retRow)) { + resState + } else { + null.asInstanceOf[S] + } + } else { + null.asInstanceOf[S] + } + } + + /** Function to update and overwrite state associated with given key */ + override def update(newState: S): Unit = { + val encodedValue = stateTypesEncoder.encodeValue(newState, ttlExpirationMs) + val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey() + store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey), + encodedValue, stateName) + upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey) + } + + /** 
Function to remove state for given key */ + override def clear(): Unit = { + store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + } + + def clearIfExpired(groupingKey: Array[Byte]): Unit = { + val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey) + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + if (isExpired(retRow)) { + store.remove(encodedGroupingKey, stateName) + } + } + } + + private def isExpired(valueRow: UnsafeRow): Boolean = { + val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow) + expirationMs.exists(StateTTL.isExpired(_, batchTimestampMs)) + } + + /* + * Internal methods to probe state for testing. The below methods exist for unit tests + * to read the state ttl values, and ensure that values are persisted correctly in + * the underlying state store. + */ + + /** + * Retrieves the value from State even if its expired. This method is used + * in tests to read the state store value, and ensure if its cleaned up at the + * end of the micro-batch. + */ + private[sql] def getWithoutEnforcingTTL(): Option[S] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + val resState = stateTypesEncoder.decodeValue(retRow) + Some(resState) + } else { + None + } + } + + /** + * Read the ttl value associated with the grouping key. + */ + private[sql] def getTTLValue(): Option[Long] = { + val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey() + val retRow = store.get(encodedGroupingKey, stateName) + + if (retRow != null) { + stateTypesEncoder.decodeTtlExpirationMs(retRow) + } else { + None + } + } + + /** + * Get all ttl values stored in ttl state for current implicit + * grouping key. 
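The scaladoc just above describes a test probe that returns every TTL value recorded for the current implicit grouping key. The technique is a lookahead iterator that filters a shared TTL index by serialized key, using `sameElements` because Scala arrays compare by reference. The following is a self-contained sketch. It assumes a hypothetical `TtlIndexEntry` with only `expirationMs` and `groupingKey` fields, modeled on the fields the diff reads.

```scala
// Hypothetical TTL index entry: an expiration time plus the serialized
// grouping key it belongs to. Not a Spark type.
final case class TtlIndexEntry(expirationMs: Long, groupingKey: Array[Byte])

object TtlIndexSketch {
  /** Lazily yields the expiration times recorded for `key`, skipping entries for other keys. */
  def valuesForKey(index: Iterator[TtlIndexEntry], key: Array[Byte]): Iterator[Long] =
    new Iterator[Long] {
      // One-element lookahead buffer filled by hasNext.
      private var nextValue: Option[Long] = None

      override def hasNext: Boolean = {
        // Advance the underlying index until a matching entry is buffered or it runs out.
        while (nextValue.isEmpty && index.hasNext) {
          val entry = index.next()
          // Arrays use reference equality, so compare their contents explicitly.
          if (entry.groupingKey.sameElements(key)) nextValue = Some(entry.expirationMs)
        }
        nextValue.isDefined
      }

      override def next(): Long = {
        // Calling hasNext here makes next() safe even without a preceding hasNext call.
        if (!hasNext) throw new NoSuchElementException("no more TTL values for this key")
        val result = nextValue.get
        nextValue = None
        result
      }
    }
}
```

The same result could be written as `index.filter(_.groupingKey.sameElements(key)).map(_.expirationMs)`. The hand-written form is shown only because it mirrors the structure used in the commit.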
+ */ + private[sql] def getValuesInTTLState(): Iterator[Long] = { + val ttlIterator = ttlIndexIterator() + val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey() + var nextValue: Option[Long] = None + + new Iterator[Long] { + override def hasNext: Boolean = { + while (nextValue.isEmpty && ttlIterator.hasNext) { + val nextTtlValue = ttlIterator.next() + val groupingKey = nextTtlValue.groupingKey + if (groupingKey sameElements implicitGroupingKey) { + nextValue = Some(nextTtlValue.expirationMs) + } + } + nextValue.isDefined + } + + override def next(): Long = { + val result = nextValue.get + nextValue = None + result + } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 2f72cbb0b0fc..6c63aa94e75b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -39,6 +39,13 @@ object StateStoreErrors { ) } + def missingTTLValues(ttlMode: String): SparkException = { + SparkException.internalError( + msg = s"Failed to find timeout values for ttlMode=$ttlMode", + category = "TWS" + ) + } + def unsupportedOperationOnMissingColumnFamily(operationName: String, colFamilyName: String): StateStoreUnsupportedOperationOnMissingColumnFamily = { new StateStoreUnsupportedOperationOnMissingColumnFamily(operationName, colFamilyName) @@ -117,6 +124,16 @@ object StateStoreErrors { StatefulProcessorCannotReInitializeState = { new StatefulProcessorCannotReInitializeState(groupingKey) } + + def cannotProvideTTLConfigForNoTTLMode(stateName: String): + StatefulProcessorCannotAssignTTLInNoTTLMode = { + new StatefulProcessorCannotAssignTTLInNoTTLMode(stateName) + } + + def ttlMustBePositive(operationType: String, + stateName: String): StatefulProcessorTTLMustBePositive 
= { + new StatefulProcessorTTLMustBePositive(operationType, stateName) + } } class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) @@ -192,3 +209,15 @@ class StateStoreNullTypeOrderingColsNotSupported(fieldName: String, index: Strin extends SparkUnsupportedOperationException( errorClass = "STATE_STORE_NULL_TYPE_ORDERING_COLS_NOT_SUPPORTED", messageParameters = Map("fieldName" -> fieldName, "index" -> index)) + +class StatefulProcessorCannotAssignTTLInNoTTLMode(stateName: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE", + messageParameters = Map("stateName" -> stateName)) + +class StatefulProcessorTTLMustBePositive( + operationType: String, + stateName: String) + extends SparkUnsupportedOperationException( + errorClass = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE", + messageParameters = Map("operationType" -> operationType, "stateName" -> stateName)) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 02927f1d962f..f9f075f4468d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -207,6 +207,7 @@ public class JavaDatasetSuite implements Serializable { Dataset<String> transformWithStateMapped = grouped.transformWithState( new TestStatefulProcessorWithInitialState(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append(), kvInitStateMappedDS, Encoders.STRING(), @@ -362,6 +363,7 @@ public class JavaDatasetSuite implements Serializable { Dataset<String> transformWithStateMapped = grouped.transformWithState( testStatefulProcessor, TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append(), Encoders.STRING()); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java index 3122e0e337a3..c6d705af5f2d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java @@ -36,7 +36,10 @@ public class TestStatefulProcessor extends StatefulProcessor<Integer, String, St private transient ListState<String> keysList; @Override - public void init(OutputMode outputMode, TimeoutMode timeoutMode) { + public void init( + OutputMode outputMode, + TimeoutMode timeoutMode, + TTLMode ttlMode) { countState = this.getHandle().getValueState("countState", Encoders.LONG()); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java index 247bae3a3f3c..db0b222145c4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java @@ -35,7 +35,10 @@ public class TestStatefulProcessorWithInitialState private transient ValueState<String> testState; @Override - public void init(OutputMode outputMode, TimeoutMode timeoutMode) { + public void init( + OutputMode outputMode, + TimeoutMode timeoutMode, + TTLMode ttlMode) { testState = this.getHandle().getValueState("testState", Encoders.STRING()); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index e895e475b74d..51cfc1548b39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkIllegalArgumentException import 
org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, TimeoutMode, TTLMode, ValueState} /** * Class that adds unit tests for ListState types used in arbitrary stateful @@ -37,7 +37,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong) @@ -47,7 +48,7 @@ class ListStateSuite extends StateVariableSuiteBase { } checkError( - exception = e.asInstanceOf[SparkIllegalArgumentException], + exception = e, errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", sqlState = Some("42601"), parameters = Map("stateName" -> "listState") @@ -70,7 +71,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -98,7 +100,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val 
store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ListState[Long] = handle.getListState[Long]("testState1", Encoders.scalaLong) val testState2: ListState[Long] = handle.getListState[Long]("testState2", Encoders.scalaLong) @@ -136,7 +139,8 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val listState1: ListState[Long] = handle.getListState[Long]("listState1", Encoders.scalaLong) val listState2: ListState[Long] = handle.getListState[Long]("listState2", Encoders.scalaLong) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index ce72061d39ea..7fa41b12795e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} -import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, TTLMode, ValueState} import 
org.apache.spark.sql.types.{BinaryType, StructType} /** @@ -39,7 +39,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: MapState[String, Double] = handle.getMapState[String, Double]("testState", Encoders.STRING, Encoders.scalaDouble) @@ -73,7 +74,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: MapState[Long, Double] = handle.getMapState[Long, Double]("testState1", Encoders.scalaLong, Encoders.scalaDouble) @@ -112,7 +114,8 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val mapTestState1: MapState[String, Int] = handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, Encoders.scalaInt) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala index 
662a5dbfaac4..a32b4111eae8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.streaming.state +import java.time.Duration import java.util.UUID import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} -import org.apache.spark.sql.streaming.TimeoutMode +import org.apache.spark.sql.streaming.{TimeoutMode, TTLConfig, TTLMode} + /** * Class that adds tests to verify operations based on stateful processor handle @@ -48,7 +50,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) assert(handle.getHandleState === StatefulProcessorHandleState.CREATED) handle.getValueState[Long]("testState", Encoders.scalaLong) } @@ -89,7 +91,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.INITIALIZED, StatefulProcessorHandleState.DATA_PROCESSED, @@ -107,7 +109,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { 
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeoutMode.NoTimeouts()) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val ex = intercept[SparkUnsupportedOperationException] { handle.registerTimer(10000L) } @@ -143,7 +145,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.INITIALIZED) assert(handle.getHandleState === StatefulProcessorHandleState.INITIALIZED) @@ -164,7 +166,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) assert(handle.getHandleState === StatefulProcessorHandleState.DATA_PROCESSED) @@ -204,7 +206,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), getTimeoutMode(timeoutMode)) Seq(StatefulProcessorHandleState.CREATED, StatefulProcessorHandleState.TIMER_PROCESSED, @@ -216,4 +218,34 @@ class 
StatefulProcessorHandleSuite extends StateVariableSuiteBase { } } } + + test(s"ttl States are populated for ttlMode=ProcessingTime") { + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, + UUID.randomUUID(), keyExprEncoder, TTLMode.ProcessingTimeTTL(), TimeoutMode.NoTimeouts(), + batchTimestampMs = Some(10)) + + val valueStateWithTTL = handle.getValueState("testState", + Encoders.STRING, TTLConfig(Duration.ofHours(1))) + + // create another state without TTL, this should not be captured in the handle + handle.getValueState("testState", Encoders.STRING) + + assert(handle.ttlStates.size() === 1) + assert(handle.ttlStates.get(0) === valueStateWithTTL) + } + } + + test(s"ttl States are not populated for ttlMode=NoTTL") { + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, + UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) + + handle.getValueState("testState", Encoders.STRING) + + assert(handle.ttlStates.isEmpty) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 8668b58672c7..102164d9c15f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming.state +import java.time.Duration import java.util.UUID import scala.util.Random @@ -27,9 +28,9 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.Encoders import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, ValueStateImplWithTTL} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{TimeoutMode, ValueState} +import org.apache.spark.sql.streaming.{TimeoutMode, TTLConfig, TTLMode, ValueState} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -48,7 +49,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val stateName = "testState" val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -78,7 +80,7 @@ class ValueStateSuite extends StateVariableSuiteBase { testState.update(123) } checkError( - ex.asInstanceOf[SparkException], + ex1.asInstanceOf[SparkException], errorClass = "INTERNAL_ERROR_TWS", parameters = Map( "message" -> s"Implicit key not found in state store for stateName=$stateName" @@ -92,7 +94,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) 
ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -118,7 +121,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState1: ValueState[Long] = handle.getValueState[Long]( "testState1", Encoders.scalaLong) @@ -164,7 +168,7 @@ class ValueStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], - TimeoutMode.NoTimeouts()) + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val cfName = "_testState" val ex = intercept[SparkUnsupportedOperationException] { @@ -204,7 +208,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Double] = handle.getValueState[Double]("testState", Encoders.scalaDouble) @@ -230,7 +235,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[Long] = 
handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -256,7 +262,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState", Encoders.product[TestClass]) @@ -282,7 +289,8 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.NoTTL(), TimeoutMode.NoTimeouts()) val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState", Encoders.bean(classOf[POJOTestClass])) @@ -303,6 +311,93 @@ class ValueStateSuite extends StateVariableSuiteBase { assert(testState.get() === null) } } + + + test(s"test Value state TTL") { + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + val timestampMs = 10 + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.ProcessingTimeTTL(), TimeoutMode.NoTimeouts(), + batchTimestampMs = Some(timestampMs)) + + val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) + val testState: ValueStateImplWithTTL[String] = handle.getValueState[String]("testState", + Encoders.STRING, ttlConfig).asInstanceOf[ValueStateImplWithTTL[String]] + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + 
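The `test Value state TTL` case set up above creates a processing-time TTL state at `batchTimestampMs = 10` with a one-minute `TTLConfig`. The model below captures the behavior the PR description promises: `get()` stops returning a value once it expires, but the row stays physically present until end-of-batch cleanup removes it. `TtlValueStateModel` is a hypothetical in-memory stand-in for the `StateStore`-backed `ValueStateImplWithTTL`. It is not Spark API.

```scala
import scala.collection.mutable

// Hypothetical in-memory model of a TTL-enabled value state; not Spark API.
final class TtlValueStateModel[K, V](ttlMs: Long) {
  // key -> (value, absolute expiration time in ms); survives across "batches".
  private val rows = mutable.Map.empty[K, (V, Long)]

  /** Write path: stamp the value with batchTimestampMs + ttlMs. */
  def update(key: K, value: V, batchTimestampMs: Long): Unit =
    rows(key) = (value, batchTimestampMs + ttlMs)

  /** Read path: hide the value once its expiration time <= the batch timestamp. */
  def get(key: K, batchTimestampMs: Long): Option[V] =
    rows.get(key).collect { case (v, exp) if exp > batchTimestampMs => v }

  /** Raw read that ignores TTL, analogous to the commit's test-only probe. */
  def getWithoutEnforcingTTL(key: K): Option[V] = rows.get(key).map(_._1)

  /** End-of-batch cleanup: physically remove every expired row. */
  def clearExpired(batchTimestampMs: Long): Unit = {
    val expired = rows.collect { case (k, (_, exp)) if exp <= batchTimestampMs => k }.toList
    rows --= expired
  }
}
```

For example, with `ttlMs = 60000`, a value written at timestamp 10 is visible to `get` at timestamp 10 but not at 60010. It remains readable through `getWithoutEnforcingTTL` until `clearExpired(60010)` runs. This model scans every row during cleanup. The commit instead records expirations in a separate TTL index (`upsertTTLForStateKey`, `ttlIndexIterator`), presumably so that cleanup can avoid a full scan of the state.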
testState.update("v1") + assert(testState.get() === "v1") + assert(testState.getWithoutEnforcingTTL().get === "v1") + + val ttlExpirationMs = timestampMs + 60000 + var ttlValue = testState.getTTLValue() + assert(ttlValue.isDefined) + assert(ttlValue.get === ttlExpirationMs) + var ttlStateValueIterator = testState.getValuesInTTLState() + assert(ttlStateValueIterator.hasNext) + + // increment batchProcessingTime, or watermark and ensure expired value is not returned + val nextBatchHandle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.ProcessingTimeTTL(), TimeoutMode.NoTimeouts(), + batchTimestampMs = Some(ttlExpirationMs)) + + val nextBatchTestState: ValueStateImplWithTTL[String] = + nextBatchHandle.getValueState[String]("testState", Encoders.STRING, ttlConfig) + .asInstanceOf[ValueStateImplWithTTL[String]] + + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + + // ensure get does not return the expired value + assert(!nextBatchTestState.exists()) + assert(nextBatchTestState.get() === null) + + // ttl value should still exist in state + ttlValue = nextBatchTestState.getTTLValue() + assert(ttlValue.isDefined) + assert(ttlValue.get === ttlExpirationMs) + ttlStateValueIterator = nextBatchTestState.getValuesInTTLState() + assert(ttlStateValueIterator.hasNext) + assert(ttlStateValueIterator.next() === ttlExpirationMs) + assert(ttlStateValueIterator.isEmpty) + + // getWithoutTTL should still return the expired value + assert(nextBatchTestState.getWithoutEnforcingTTL().get === "v1") + + nextBatchTestState.clear() + assert(!nextBatchTestState.exists()) + assert(nextBatchTestState.get() === null) + } + } + + test("test negative or zero TTL duration throws error") { + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + val batchTimestampMs = 10 + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + 
Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + TTLMode.ProcessingTimeTTL(), TimeoutMode.NoTimeouts(), + batchTimestampMs = Some(batchTimestampMs)) + + Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration => + val ttlConfig = TTLConfig(ttlDuration) + val ex = intercept[SparkUnsupportedOperationException] { + handle.getValueState[String]("testState", Encoders.STRING, ttlConfig) + } + + checkError( + ex, + errorClass = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE", + parameters = Map( + "operationType" -> "update", + "stateName" -> "testState" + ), + matchPVals = true + ) + } + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 95ab34d40131..5ccc14ab8a77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -32,7 +32,8 @@ class TestListStateProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _listState = getHandle.getListState("testListState", Encoders.STRING) } @@ -89,7 +90,8 @@ class ToggleSaveAndEmitProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _listState = getHandle.getListState("testListState", Encoders.STRING) _valueState = getHandle.getValueState("testValueState", Encoders.scalaBoolean) } @@ -140,6 +142,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -160,6 +163,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) 
.transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -180,6 +184,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -200,6 +205,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -220,6 +226,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -240,6 +247,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -260,6 +268,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestListStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update()) ( @@ -312,6 +321,7 @@ class TransformWithListStateSuite extends StreamTest .groupByKey(x => x) .transformWithState(new ToggleSaveAndEmitProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index db8cb8b810af..d32b9687d95f 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -32,7 +32,8 @@ class TestMapStateProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _mapState = getHandle.getMapState("sessionState", Encoders.STRING, Encoders.STRING) } @@ -95,6 +96,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) @@ -121,6 +123,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -145,6 +148,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -168,6 +172,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) testStream(result, OutputMode.Append())( // Test exists() @@ -222,6 +227,7 @@ class TransformWithMapStateSuite extends StreamTest .groupByKey(x => x.key) .transformWithState(new TestMapStateProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index 147a13251044..106f228ba78b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -36,7 +36,10 @@ abstract class StatefulProcessorWithInitialStateTestClass[V] @transient var _listState: ListState[Double] = _ @transient var _mapState: MapState[Double, Int] = _ - override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = { + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _valState = getHandle.getValueState[Double]("testValueInit", Encoders.scalaDouble) _listState = getHandle.getListState[Double]("testListInit", Encoders.scalaDouble) _mapState = getHandle.getMapState[Double, Int]( @@ -168,7 +171,8 @@ class StatefulProcessorWithInitialStateProcTimerClass override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode) : Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode) : Unit = { _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) } @@ -211,7 +215,8 @@ class StatefulProcessorWithInitialStateEventTimerClass override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState", Encoders.scalaLong) _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) @@ -288,7 +293,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest InputRowForInitialState("init_2", 100.0, List(100.0), Map(100.0 -> 1))) .toDS().groupByKey(x => x.key).mapValues(x => x) val query = kvDataSet.transformWithState(new InitialStateInMemoryTestClass(), - TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf) + TimeoutMode.NoTimeouts(), TTLMode.NoTTL(), OutputMode.Append(), initStateDf) testStream(query, 
OutputMode.Update())( // non-exist key test @@ -366,7 +371,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest val query = inputData.toDS() .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), - TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf + TimeoutMode.NoTimeouts(), TTLMode.NoTTL(), OutputMode.Append(), initStateDf ) testStream(query, OutputMode.Update())( AddData(inputData, InitInputRow("init_1", "add", 50.0)), @@ -387,6 +392,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append(), createInitialDfForTest) @@ -405,6 +411,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append(), initDf) @@ -437,6 +444,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest .transformWithState( new StatefulProcessorWithInitialStateProcTimerClass(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update(), initDf) @@ -481,6 +489,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest .transformWithState( new StatefulProcessorWithInitialStateEventTimerClass(), TimeoutMode.EventTime(), + TTLMode.NoTTL(), OutputMode.Update(), initDf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 2fd1eac179da..735c53bf3c91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -40,7 +40,8 @@ class RunningCountStatefulProcessor extends 
StatefulProcessor[String, String, (S override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) } @@ -103,8 +104,9 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode) : Unit = { - super.init(outputMode, timeoutMode) + timeoutMode: TimeoutMode, + ttlMode: TTLMode) : Unit = { + super.init(outputMode, timeoutMode, ttlMode) _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) } @@ -194,7 +196,8 @@ class MaxEventTimeStatefulProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState", Encoders.scalaLong) _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) @@ -239,10 +242,12 @@ class RunningCountMostRecentStatefulProcessor override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } + override def handleInputRows( key: String, inputRows: Iterator[(String, String)], @@ -268,7 +273,8 @@ class MostRecentStatefulProcessorWithDeletion override def init( outputMode: OutputMode, - timeoutMode: TimeoutMode): Unit = { + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { getHandle.deleteIfExists("countState") _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } @@ -322,6 +328,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessorWithError(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), 
OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -343,6 +350,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -373,6 +381,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -416,6 +425,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithProcTimeTimerUpdates(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -452,6 +462,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new RunningCountStatefulProcessorWithMultipleTimers(), TimeoutMode.ProcessingTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -487,6 +498,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .transformWithState( new MaxEventTimeStatefulProcessor(), TimeoutMode.EventTime(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -528,6 +540,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Append()) val df = result.toDF() @@ -546,12 +559,14 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x._1) .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) val stream2 = inputData.toDS() .groupByKey(x => x._1) .transformWithState(new MostRecentStatefulProcessorWithDeletion(), 
TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(stream1, OutputMode.Update())( @@ -584,6 +599,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -617,6 +633,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -650,6 +667,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -680,6 +698,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) } @@ -772,6 +791,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { .groupByKey(x => x) .transformWithState(new RunningCountStatefulProcessor(), TimeoutMode.NoTimeouts(), + TTLMode.NoTTL(), OutputMode.Update()) testStream(result, OutputMode.Update())( @@ -790,7 +810,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest { val result = inputData.toDS() .groupByKey(x => x.key) .transformWithState(new AccumulateStatefulProcessorWithInitState(), - TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf + TimeoutMode.NoTimeouts(), TTLMode.NoTTL(), OutputMode.Append(), initDf ) testStream(result, OutputMode.Update())( AddData(inputData, InitInputRow("a", "add", -1.0)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala new file mode 100644 index 000000000000..759d535c18a3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -0,0 +1,471 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.sql.Timestamp +import java.time.Duration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl, ValueStateImplWithTTL} +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock + +case class InputEvent( + key: String, + action: String, + value: Int, + eventTime: Timestamp = null) + +case class OutputEvent( + key: String, + value: Int, + isTTLValue: Boolean, + ttlValue: Long) + +object TTLInputProcessFunction { + def processRow( + row: InputEvent, + valueState: ValueStateImplWithTTL[Int]): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val key = row.key + if (row.action == "get") { + val currState = valueState.getOption() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_without_enforcing_ttl") { + val currState = valueState.getWithoutEnforcingTTL() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "get_ttl_value_from_state") { + val ttlExpiration = valueState.getTTLValue() + if (ttlExpiration.isDefined) { + results = OutputEvent(key, -1, isTTLValue = true, ttlExpiration.get) :: results + } + } else if (row.action == "put") { + valueState.update(row.value) + } else if (row.action == "get_values_in_ttl_state") { + val ttlValues = valueState.getValuesInTTLState() + ttlValues.foreach { v => + results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results + } + } + + results.iterator + } + + def processNonTTLStateRow( + row: InputEvent, + valueState: ValueStateImpl[Int]): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + val key = row.key + if 
(row.action == "get") { + val currState = valueState.getOption() + if (currState.isDefined) { + results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results + } + } else if (row.action == "put") { + valueState.update(row.value) + } + + results.iterator + } +} + +class ValueStateTTLProcessor(ttlConfig: TTLConfig) + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _valueState: ValueStateImplWithTTL[Int] = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { + _valueState = getHandle + .getValueState("valueState", Encoders.scalaInt, ttlConfig) + .asInstanceOf[ValueStateImplWithTTL[Int]] + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + var results = List[OutputEvent]() + + inputRows.foreach { row => + val resultIter = TTLInputProcessFunction.processRow(row, _valueState) + resultIter.foreach { r => + results = r :: results + } + } + + results.iterator + } +} + +case class MultipleValueStatesTTLProcessor( + ttlKey: String, + noTtlKey: String, + ttlConfig: TTLConfig) + extends StatefulProcessor[String, InputEvent, OutputEvent] + with Logging { + + @transient private var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _ + @transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _ + + override def init( + outputMode: OutputMode, + timeoutMode: TimeoutMode, + ttlMode: TTLMode): Unit = { + _valueStateWithTTL = getHandle + .getValueState("valueState", Encoders.scalaInt, ttlConfig) + .asInstanceOf[ValueStateImplWithTTL[Int]] + _valueStateWithoutTTL = getHandle + .getValueState("valueState", Encoders.scalaInt) + .asInstanceOf[ValueStateImpl[Int]] + } + + override def handleInputRows( + key: String, + inputRows: Iterator[InputEvent], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): 
Iterator[OutputEvent] = { + var results = List[OutputEvent]() + + if (key == ttlKey) { + inputRows.foreach { row => + val resultIterator = TTLInputProcessFunction.processRow(row, _valueStateWithTTL) + resultIterator.foreach { r => + results = r :: results + } + } + } else { + inputRows.foreach { row => + val resultIterator = TTLInputProcessFunction.processNonTTLStateRow(row, + _valueStateWithoutTTL) + resultIterator.foreach { r => + results = r :: results + } + } + } + + results.iterator + } +} + +/** + * Tests that ttl works as expected for Value State for + * processing time and event time based ttl. + */ +class TransformWithValueStateTTLSuite + extends StreamTest { + import testImplicits._ + + test("validate state is evicted at ttl expiry") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + withTempDir { dir => + val inputStream = MemoryStream[InputEvent] + val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(ttlConfig), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL(), + OutputMode.Append()) + + val clock = new StreamManualClock + testStream(result)( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = dir.getAbsolutePath), + AddData(inputStream, InputEvent("k1", "put", 1)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + StopStream, + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = dir.getAbsolutePath), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + StopStream, + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = 
dir.getAbsolutePath), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + StopStream, + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = dir.getAbsolutePath), + // advance clock so that state expires + AdvanceManualClock(60 * 1000), + AddData(inputStream, InputEvent("k1", "get", -1, null)), + AdvanceManualClock(1 * 1000), + // validate expired value is not returned + CheckNewAnswer(), + // ensure this state does not exist any longer in State + AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + } + + test("validate state update updates the expiration timestamp") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + val inputStream = MemoryStream[InputEvent] + val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(ttlConfig), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL(), + OutputMode.Append()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1)), + 
AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock and update expiration time + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k1", "put", 1)), + AddData(inputStream, InputEvent("k1", "get", -1)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate value is not expired + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate ttl value is updated in the state + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)), + // validate ttl state has both ttl values present + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000), + OutputEvent("k1", -1, isTTLValue = true, 95000) + ), + // advance clock after older expiration value + AdvanceManualClock(30 * 1000), + // ensure unexpired value is still present in the state + AddData(inputStream, InputEvent("k1", "get", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // validate that the older expiration value is removed from ttl state + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)) + ) + } + } + + test("validate state is evicted at ttl expiry for no data batch") { + 
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + val inputStream = MemoryStream[InputEvent] + val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(ttlConfig), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL(), + OutputMode.Append()) + + val clock = new StreamManualClock + testStream(result)( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get this state, and make sure we get unexpired value + AddData(inputStream, InputEvent("k1", "get", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)), + // ensure ttl values were added correctly + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)), + // advance clock so that state expires + AdvanceManualClock(60 * 1000), + // run a no data batch + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get", -1)), + AdvanceManualClock(1 * 1000), + // validate expired value is not returned + CheckNewAnswer(), + // ensure this state does not exist any longer in State + AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate multiple value states") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + 
classOf[RocksDBStateStoreProvider].getName) { + val ttlKey = "k1" + val noTtlKey = "k2" + val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) + val inputStream = MemoryStream[InputEvent] + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL(), + OutputMode.Append()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent(ttlKey, "put", 1)), + AddData(inputStream, InputEvent(noTtlKey, "put", 2)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // get both state values, and make sure we get unexpired value + AddData(inputStream, InputEvent(ttlKey, "get", -1)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + OutputEvent(ttlKey, 1, isTTLValue = false, -1), + OutputEvent(noTtlKey, 2, isTTLValue = false, -1) + ), + // ensure ttl values were added correctly, and noTtlKey has no ttl values + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1)), + AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", -1)), + AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)), + // advance clock after expiry + AdvanceManualClock(60 * 1000), + AddData(inputStream, InputEvent(ttlKey, "get", -1)), + AddData(inputStream, InputEvent(noTtlKey, "get", -1)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + // validate ttlKey is expired, but noTtlKey is still present
CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)), + // validate ttl value is removed in the value state column family + AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer() + ) + } + } + + test("validate only expired keys are removed from the state") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inputStream = MemoryStream[InputEvent] + val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new ValueStateTTLProcessor(ttlConfig), + TimeoutMode.NoTimeouts(), + TTLMode.ProcessingTimeTTL(), + OutputMode.Append()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputStream, InputEvent("k1", "put", 1)), + // advance clock to trigger processing + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // advance clock halfway to expiration ttl, and add another key + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k2", "put", 2)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // advance clock so that key k1 is expired + AdvanceManualClock(30 * 1000), + AddData(inputStream, InputEvent("k1", "get", 1)), + AddData(inputStream, InputEvent("k2", "get", -1)), + AdvanceManualClock(1 * 1000), + // validate k1 is expired and k2 is not + CheckNewAnswer(OutputEvent("k2", 2, isTTLValue = false, -1)), + // validate k1 is deleted from state + AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)), + AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + // validate k2 exists in state + AddData(inputStream, InputEvent("k2", "get_ttl_value_from_state", -1)), + AddData(inputStream, InputEvent("k2", 
"get_values_in_ttl_state", -1)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + OutputEvent("k2", -1, isTTLValue = true, 92000), + OutputEvent("k2", -1, isTTLValue = true, 92000)) + ) + } + } +}
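The expected values in `TransformWithValueStateTTLSuite` follow from processing-time TTL being computed as batch timestamp plus `ttlDuration`. The manual clock is advanced by 1 s to trigger the first batch, so a `put` processed there expires at 1000 + 60000 = 61000 ms. The later `put` at 35000 ms yields 95000, and `k2`, written at 32000 ms, yields 92000. The tests also show that an update leaves the older expiration in the TTL index until it passes (`get_values_in_ttl_state` returns both 61000 and 95000), and that cleanup then drops only the stale index entry while the value survives. The sketch below is a simplified, hypothetical in-memory model of that behaviour, not Spark's `ValueStateImplWithTTL`. The class and method names are illustrative, and treating a value as expired once the batch timestamp reaches its expiration (`>=`) is an assumption.

```scala
import java.time.Duration
import scala.collection.mutable

// Hypothetical, simplified in-memory model of processing-time TTL for one
// value state. It is NOT Spark's ValueStateImplWithTTL; all names are illustrative.
class TtlValueStateModel(ttl: Duration) {
  // Mirrors the suite's validation cases: null, zero and negative durations are rejected.
  require(ttl != null && !ttl.isZero && !ttl.isNegative, "ttlDuration must be positive")

  // Primary state: grouping key -> (value, expirationMs).
  private val primary = mutable.Map.empty[String, (Int, Long)]
  // Secondary TTL index of (expirationMs, key), ordered by expiration time.
  private val ttlIndex = mutable.SortedSet.empty[(Long, String)]

  // Processing-time TTL: expiration = batch timestamp + ttl duration.
  def expirationMs(batchTimestampMs: Long): Long = batchTimestampMs + ttl.toMillis

  // Assumption: a value is expired once the batch timestamp reaches its expiration.
  private def isExpired(expMs: Long, batchTimestampMs: Long): Boolean =
    batchTimestampMs >= expMs

  // Overwrites the value and adds a new index entry; the stale entry stays until cleanup.
  def update(key: String, value: Int, batchTimestampMs: Long): Unit = {
    val exp = expirationMs(batchTimestampMs)
    primary(key) = (value, exp)
    ttlIndex += ((exp, key))
  }

  // TTL-enforcing read: expired values are hidden even before cleanup runs.
  def get(key: String, batchTimestampMs: Long): Option[Int] =
    primary.get(key).collect { case (v, exp) if !isExpired(exp, batchTimestampMs) => v }

  // Raw read that ignores TTL (similar in spirit to getWithoutEnforcingTTL in the suite).
  def getWithoutEnforcingTtl(key: String): Option[Int] = primary.get(key).map(_._1)

  // All expiration timestamps currently indexed for a key, in ascending order.
  def valuesInTtlIndex(key: String): Seq[Long] =
    ttlIndex.toSeq.collect { case (exp, k) if k == key => exp }

  // End-of-batch cleanup: drop every expired index entry, but remove the value
  // itself only if its current expiration has also passed.
  def cleanup(batchTimestampMs: Long): Unit = {
    val expiredEntries =
      ttlIndex.takeWhile { case (exp, _) => isExpired(exp, batchTimestampMs) }.toList
    expiredEntries.foreach { case entry @ (_, key) =>
      ttlIndex -= entry
      primary.get(key).foreach { case (_, currentExp) =>
        if (isExpired(currentExp, batchTimestampMs)) primary -= key
      }
    }
  }
}
```

Replaying "validate state update updates the expiration timestamp" against this model, `update("k1", 1, 1000L)` followed by `update("k1", 1, 35000L)` leaves index entries 61000 and 95000. A `cleanup(68000L)` then removes only the 61000 entry while `get("k1", 68000L)` still returns `Some(1)`. The sorted index lets cleanup stop at the first unexpired entry instead of scanning every key.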