This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d55bb617a135 [SPARK-47558][SS] State TTL support for ValueState
d55bb617a135 is described below

commit d55bb617a13561f0eb9f301089a4e4fb06e06228
Author: Bhuwan Sahni <bhuwan.sa...@databricks.com>
AuthorDate: Mon Apr 8 12:22:04 2024 +0900

    [SPARK-47558][SS] State TTL support for ValueState
    
    **Note**: This change has been co-authored by ericm-db and sahnib
    
    **Authors: ericm-db sahnib**
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for expiring state based on TTL for ValueState. With
    this functionality, Spark users can specify a TTL mode for the
    transformWithState operator and provide a ttlDuration/expirationTimeInMs
    for each value in ValueState. TTL support for List/Map State will be added
    in future PRs. Once the ttlDuration has expired, the value is no longer
    returned as part of `get()` and is cleaned up at the end of the
    micro-batch.
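
    As a rough usage sketch (illustrative, not taken from this patch): a user
    registers a value state with TTL through the new `getValueState` overload
    and `TTLConfig` added below, assuming the processor reads its handle via
    `getHandle`. The processor class and state name here are hypothetical.
    
    ```
    import java.time.Duration
    
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming._
    
    // Hypothetical processor; only the TTL wiring reflects this PR.
    class CountProcessor extends StatefulProcessor[String, String, Long] {
      @transient private var countState: ValueState[Long] = _
    
      override def init(
          outputMode: OutputMode,
          timeoutMode: TimeoutMode,
          ttlMode: TTLMode): Unit = {
        // Values expire one hour after the last update; expired values are
        // no longer returned by get() and are removed at micro-batch end.
        countState = getHandle.getValueState[Long](
          "countState", Encoders.scalaLong, TTLConfig(Duration.ofHours(1)))
      }
    
      // handleInputRows and the remaining callbacks are omitted here.
    }
    ```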
    
    ### Why are the changes needed?
    
    These changes are needed to support TTL for ValueState. The PR supports
    specifying TTL in either processing time or event time. Processing-time
    TTL is calculated by adding ttlDuration to the `batchTimestamp`, and
    event-time TTL is specified as an absolute expiration time
    (`expirationTimeInMs`).
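
    Concretely, the expiration arithmetic matches the `StateTTL` helper added
    in `TTLState.scala` below (timestamps here are illustrative):
    
    ```
    import java.time.Duration
    
    // Mirrors StateTTL.calculateExpirationTimeForDuration and
    // StateTTL.isExpired from the new TTLState.scala.
    val batchTimestampMs = 1700000000000L
    val expirationMs = batchTimestampMs + Duration.ofMinutes(5).toMillis
    // A value is expired once the batch timestamp reaches expirationMs.
    val expired = batchTimestampMs >= expirationMs   // false for this batch
    ```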
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. This modifies the ValueState interface to allow specifying
    `ttlDuration`, and adds a `ttlMode` parameter to the `transformWithState`
    API.
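
    For example (hypothetical dataset and processor; note the Scala
    `transformWithState` overloads are still `private[sql]` in this change,
    so the call below is illustrative only):
    
    ```
    // CountProcessor is the sketch above, with its remaining callbacks
    // filled in; assumes spark.implicits._ is in scope for encoders.
    val counts = inputDs
      .groupByKey(identity[String])
      .transformWithState(
        new CountProcessor(),
        TimeoutMode.ProcessingTime(),
        TTLMode.ProcessingTimeTTL(),
        OutputMode.Update())
    ```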
    
    ### How was this patch tested?
    
    Added unit test cases for both event time and processing time in
    `TransformWithValueStateTTLSuite`.
    
    ```
    WARNING: Using incubator modules: jdk.incubator.foreign, 
jdk.incubator.vector
    [info] TransformWithStateTTLSuite:
    11:56:54.590 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load 
native-hadoop library for your platform... using builtin-java classes where 
applicable
    11:56:56.054 WARN 
org.apache.spark.sql.execution.streaming.ResolveWriteToStream: 
spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets 
and will be disabled.
    [info] - validate state is evicted at ttl expiry - processing time ttl (6 
seconds, 244 milliseconds)
    11:57:01.188 WARN 
org.apache.spark.sql.execution.streaming.ResolveWriteToStream: 
spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets 
and will be disabled.
    [info] - validate ttl update updates the expiration timestamp - processing 
time ttl (4 seconds, 465 milliseconds)
    11:57:05.641 WARN 
org.apache.spark.sql.execution.streaming.ResolveWriteToStream: 
spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets 
and will be disabled.
    [info] - validate ttl removal keeps value in state - processing time ttl (4 
seconds, 407 milliseconds)
    11:57:10.041 WARN 
org.apache.spark.sql.execution.streaming.ResolveWriteToStream: 
spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets 
and will be disabled.
    [info] - validate multiple value states - with and without ttl - processing 
time ttl (3 seconds, 131 milliseconds)
    11:57:13.175 WARN 
org.apache.spark.sql.execution.streaming.ResolveWriteToStream: 
spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets 
and will be disabled.
    [info] - validate state is evicted at ttl expiry - event time ttl (4 
seconds, 186 milliseconds)
    11:57:17.355 WARN 
org.apache.spark.sql.execution.streaming.ResolveWriteToStream: 
spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets 
and will be disabled.
    [info] - validate ttl update updates the expiration timestamp - event time 
ttl (4 seconds, 28 milliseconds)
    11:57:21.391 WARN 
org.apache.spark.sql.execution.streaming.ResolveWriteToStream: 
spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets 
and will be disabled.
    [info] - validate ttl removal keeps value in state - event time ttl (4 
seconds, 428 milliseconds)
    11:57:25.838 WARN org.apache.spark.sql.streaming.TransformWithStateTTLSuite:
    
    [info] Run completed in 32 seconds, 433 milliseconds.
    [info] Total number of tests run: 7
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 7, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45674 from sahnib/state-ttl.
    
    Authored-by: Bhuwan Sahni <bhuwan.sa...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../src/main/resources/error/error-classes.json    |  17 +
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  14 +-
 dev/checkstyle-suppressions.xml                    |   2 +
 ...r-conditions-unsupported-feature-error-class.md |   4 +
 docs/sql-error-conditions.md                       |  12 +
 .../org/apache/spark/sql/streaming/TTLMode.java}   |  40 +-
 .../plans/logical/TTLMode.scala}                   |  36 +-
 .../spark/sql/streaming/StatefulProcessor.scala    |   3 +-
 .../sql/streaming/StatefulProcessorHandle.scala    |  25 +-
 .../{ValueState.scala => TTLConfig.scala}          |  36 +-
 .../apache/spark/sql/streaming/ValueState.scala    |   6 +-
 .../spark/sql/catalyst/plans/logical/object.scala  |   7 +-
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  24 +-
 .../spark/sql/execution/SparkStrategies.scala      |   7 +-
 .../sql/execution/streaming/ListStateImpl.scala    |   2 +-
 .../sql/execution/streaming/MapStateImpl.scala     |   2 +-
 .../streaming/StateTypesEncoderUtils.scala         |  84 +++-
 .../streaming/StatefulProcessorHandleImpl.scala    |  63 ++-
 .../spark/sql/execution/streaming/TTLState.scala   | 153 +++++++
 .../sql/execution/streaming/TimerStateImpl.scala   |   8 +-
 .../streaming/TransformWithStateExec.scala         | 105 +++--
 .../sql/execution/streaming/ValueStateImpl.scala   |  33 +-
 .../streaming/ValueStateImplWithTTL.scala          | 184 ++++++++
 .../streaming/state/StateStoreErrors.scala         |  29 ++
 .../org/apache/spark/sql/JavaDatasetSuite.java     |   2 +
 .../apache/spark/sql/TestStatefulProcessor.java    |   5 +-
 .../sql/TestStatefulProcessorWithInitialState.java |   5 +-
 .../execution/streaming/state/ListStateSuite.scala |  16 +-
 .../execution/streaming/state/MapStateSuite.scala  |  11 +-
 .../state/StatefulProcessorHandleSuite.scala       |  46 +-
 .../streaming/state/ValueStateSuite.scala          | 117 ++++-
 .../streaming/TransformWithListStateSuite.scala    |  14 +-
 .../sql/streaming/TransformWithMapStateSuite.scala |   8 +-
 .../TransformWithStateInitialStateSuite.scala      |  19 +-
 .../sql/streaming/TransformWithStateSuite.scala    |  34 +-
 .../TransformWithValueStateTTLSuite.scala          | 471 +++++++++++++++++++++
 36 files changed, 1407 insertions(+), 237 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json 
b/common/utils/src/main/resources/error/error-classes.json
index aeb35b864c66..f28adaf40230 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3579,6 +3579,12 @@
     ],
     "sqlState" : "0A000"
   },
+  "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE" : {
+    "message" : [
+      "Cannot use TTL for state=<stateName> in NoTTL() mode."
+    ],
+    "sqlState" : "42802"
+  },
   "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : {
     "message" : [
       "Failed to perform stateful processor operation=<operationType> with 
invalid handle state=<handleState>."
@@ -3597,6 +3603,12 @@
     ],
     "sqlState" : "42802"
   },
+  "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE" : {
+    "message" : [
+      "TTL duration must be greater than zero for State store 
operation=<operationType> on state=<stateName>."
+    ],
+    "sqlState" : "42802"
+  },
   "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : {
     "message" : [
       "Failed to create column family with unsupported starting character and 
name=<colFamilyName>."
@@ -4391,6 +4403,11 @@
           "Removing column families with <stateStoreProvider> is not 
supported."
         ]
       },
+      "STATE_STORE_TTL" : {
+        "message" : [
+          "State TTL with <stateStoreProvider> is not supported. Please use 
RocksDBStateStoreProvider."
+        ]
+      },
       "TABLE_OPERATION" : {
         "message" : [
           "Table <tableName> does not support <operation>. Please check the 
current catalog and namespace to make sure the qualified table name is 
expected, and also check the catalog implementation which is configured by 
\"spark.sql.catalog\"."
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 1b712348d865..39e0c429046d 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -29,7 +29,7 @@ import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
 import org.apache.spark.sql.connect.common.UdfUtils
 import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, 
OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode}
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, 
OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode, 
TTLMode}
 
 /**
  * A [[Dataset]] has been logically grouped by a user specified grouping key. 
Users should not
@@ -829,12 +829,15 @@ class KeyValueGroupedDataset[K, V] private[sql] () 
extends Serializable {
    *   Instance of statefulProcessor whose functions will be invoked by the 
operator.
    * @param timeoutMode
    *   The timeout mode of the stateful processor.
+   * @param ttlMode
+   *   The ttlMode to evict user state on ttl expiration.
    * @param outputMode
    *   The output mode of the stateful processor.
    */
   def transformWithState[U: Encoder](
       statefulProcessor: StatefulProcessor[K, V, U],
       timeoutMode: TimeoutMode,
+      ttlMode: TTLMode,
       outputMode: OutputMode): Dataset[U] = {
     throw new UnsupportedOperationException
   }
@@ -853,6 +856,8 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends 
Serializable {
    *   Instance of statefulProcessor whose functions will be invoked by the 
operator.
    * @param timeoutMode
    *   The timeout mode of the stateful processor.
+   * @param ttlMode
+   *   The ttlMode to evict user state on ttl expiration.
    * @param outputMode
    *   The output mode of the stateful processor.
    * @param outputEncoder
@@ -861,6 +866,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends 
Serializable {
   def transformWithState[U: Encoder](
       statefulProcessor: StatefulProcessor[K, V, U],
       timeoutMode: TimeoutMode,
+      ttlMode: TTLMode,
       outputMode: OutputMode,
       outputEncoder: Encoder[U]): Dataset[U] = {
     throw new UnsupportedOperationException
@@ -879,6 +885,8 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends 
Serializable {
    *   Instance of statefulProcessor whose functions will be invoked by the 
operator.
    * @param timeoutMode
    *   The timeout mode of the stateful processor.
+   * @param ttlMode
+   *   The ttlMode to evict user state on ttl expiration.
    * @param outputMode
    *   The output mode of the stateful processor.
    * @param initialState
@@ -890,6 +898,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends 
Serializable {
   def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
       timeoutMode: TimeoutMode,
+      ttlMode: TTLMode,
       outputMode: OutputMode,
       initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
     throw new UnsupportedOperationException
@@ -908,6 +917,8 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends 
Serializable {
    *   Instance of statefulProcessor whose functions will be invoked by the 
operator.
    * @param timeoutMode
    *   The timeout mode of the stateful processor.
+   * @param ttlMode
+   *   The ttlMode to evict user state on ttl expiration
    * @param outputMode
    *   The output mode of the stateful processor.
    * @param initialState
@@ -923,6 +934,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends 
Serializable {
   private[sql] def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
       timeoutMode: TimeoutMode,
+      ttlMode: TTLMode,
       outputMode: OutputMode,
       initialState: KeyValueGroupedDataset[K, S],
       outputEncoder: Encoder[U],
diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml
index 7b20dfb6bce5..94dfe20af56e 100644
--- a/dev/checkstyle-suppressions.xml
+++ b/dev/checkstyle-suppressions.xml
@@ -60,6 +60,8 @@
               
files="sql/api/src/main/java/org/apache/spark/sql/streaming/TimeoutMode.java"/>
     <suppress checks="MethodName"
               
files="sql/api/src/main/java/org/apache/spark/sql/streaming/Trigger.java"/>
+    <suppress checks="MethodName"
+              
files="sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java"/>
     <suppress checks="LineLength"
               files="src/main/java/org/apache/spark/sql/api/java/*"/>
     <suppress checks="IllegalImport"
diff --git a/docs/sql-error-conditions-unsupported-feature-error-class.md 
b/docs/sql-error-conditions-unsupported-feature-error-class.md
index e580ecc63b18..f67d7caff63d 100644
--- a/docs/sql-error-conditions-unsupported-feature-error-class.md
+++ b/docs/sql-error-conditions-unsupported-feature-error-class.md
@@ -202,6 +202,10 @@ Creating multiple column families with 
`<stateStoreProvider>` is not supported.
 
 Removing column families with `<stateStoreProvider>` is not supported.
 
+## STATE_STORE_TTL
+
+State TTL with `<stateStoreProvider>` is not supported. Please use 
RocksDBStateStoreProvider.
+
 ## TABLE_OPERATION
 
 Table `<tableName>` does not support `<operation>`. Please check the current 
catalog and namespace to make sure the qualified table name is expected, and 
also check the catalog implementation which is configured by 
"spark.sql.catalog".
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index ee3a3bd07a77..d8261b8c2765 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -2183,6 +2183,12 @@ The SQL config `<sqlConf>` cannot be found. Please 
verify that the config exists
 
 Star (*) is not allowed in a select list when GROUP BY an ordinal position is 
used.
 
+### STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE
+
+[SQLSTATE: 
42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Cannot use TTL for state=`<stateName>` in NoTTL() mode.
+
 ### STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE
 
 [SQLSTATE: 
42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
@@ -2201,6 +2207,12 @@ Failed to perform stateful processor 
operation=`<operationType>` with invalid ti
 
 Cannot re-initialize state on the same grouping key during initial state 
handling for stateful processor. Invalid grouping key=`<groupingKey>`.
 
+### STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE
+
+[SQLSTATE: 
42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+TTL duration must be greater than zero for State store 
operation=`<operationType>` on state=`<stateName>`.
+
 ### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS
 
 [SQLSTATE: 
42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala 
b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java
similarity index 53%
copy from sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
copy to sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java
index 9c707c8308ab..30594770b3e1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
+++ b/sql/api/src/main/java/org/apache/spark/sql/streaming/TTLMode.java
@@ -15,36 +15,28 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.streaming
+package org.apache.spark.sql.streaming;
 
-import java.io.Serializable
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.catalyst.plans.logical.*;
 
-import org.apache.spark.annotation.{Evolving, Experimental}
-
-@Experimental
-@Evolving
 /**
- * Interface used for arbitrary stateful operations with the v2 API to capture
- * single value state.
+ * Represents the type of ttl modes possible for the Dataset operations
+ * {@code transformWithState}.
  */
-private[sql] trait ValueState[S] extends Serializable {
-
-  /** Whether state exists or not. */
-  def exists(): Boolean
+@Experimental
+@Evolving
+public class TTLMode {
 
   /**
-   * Get the state value if it exists
-   * @throws java.util.NoSuchElementException if the state does not exist
+   * Specifies that there is no TTL for the user state. User state would not
+   * be cleaned up by Spark automatically.
    */
-  @throws[NoSuchElementException]
-  def get(): S
-
-  /** Get the state if it exists as an option and None otherwise */
-  def getOption(): Option[S]
-
-  /** Update the value of the state. */
-  def update(newState: S): Unit
+  public static final TTLMode NoTTL() { return NoTTL$.MODULE$; }
 
-  /** Remove this state. */
-  def clear(): Unit
+  /**
+   * Specifies that all ttl durations for user state are in processing time.
+   */
+  public static final TTLMode ProcessingTimeTTL() { return 
ProcessingTimeTTL$.MODULE$; }
 }
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala
similarity index 50%
copy from sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
copy to 
sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala
index 9c707c8308ab..be4794a5f40b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TTLMode.scala
@@ -14,37 +14,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+package org.apache.spark.sql.catalyst.plans.logical
 
-package org.apache.spark.sql.streaming
+import org.apache.spark.sql.streaming.TTLMode
 
-import java.io.Serializable
+/** TTL types used in tranformWithState operator */
+case object NoTTL extends TTLMode
 
-import org.apache.spark.annotation.{Evolving, Experimental}
-
-@Experimental
-@Evolving
-/**
- * Interface used for arbitrary stateful operations with the v2 API to capture
- * single value state.
- */
-private[sql] trait ValueState[S] extends Serializable {
-
-  /** Whether state exists or not. */
-  def exists(): Boolean
-
-  /**
-   * Get the state value if it exists
-   * @throws java.util.NoSuchElementException if the state does not exist
-   */
-  @throws[NoSuchElementException]
-  def get(): S
-
-  /** Get the state if it exists as an option and None otherwise */
-  def getOption(): Option[S]
-
-  /** Update the value of the state. */
-  def update(newState: S): Unit
-
-  /** Remove this state. */
-  def clear(): Unit
-}
+case object ProcessingTimeTTL extends TTLMode
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
index 42d12dd91e94..70f9cdfa399a 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
@@ -44,7 +44,8 @@ private[sql] abstract class StatefulProcessor[K, I, O] 
extends Serializable {
    */
   def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode): Unit
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit
 
   /**
    * Function that will allow users to interact with input data rows along 
with the grouping key
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
index 560188a0ff62..e65667206ded 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
@@ -30,16 +30,37 @@ import org.apache.spark.sql.Encoder
 private[sql] trait StatefulProcessorHandle extends Serializable {
 
   /**
-   * Function to create new or return existing single value state variable of 
given type
+   * Function to create new or return existing single value state variable of 
given type.
    * The user must ensure to call this function only within the `init()` 
method of the
    * StatefulProcessor.
-   * @param stateName - name of the state variable
+   *
+   * @param stateName  - name of the state variable
    * @param valEncoder - SQL encoder for state variable
    * @tparam T - type of state variable
    * @return - instance of ValueState of type T that can be used to store 
state persistently
    */
   def getValueState[T](stateName: String, valEncoder: Encoder[T]): 
ValueState[T]
 
+  /**
+   * Function to create new or return existing single value state variable of 
given type
+   * with ttl. State values will not be returned past ttlDuration, and will be 
eventually removed
+   * from the state store. Any state update resets the ttl to current 
processing time plus
+   * ttlDuration.
+   *
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+   * @tparam T - type of state variable
+   * @return - instance of ValueState of type T that can be used to store 
state persistently
+   */
+  def getValueState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ValueState[T]
+
   /**
    * Creates new or returns existing list state associated with stateName.
    * The ListState persists values of type T.
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
similarity index 53%
copy from sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
copy to sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
index 9c707c8308ab..576e09d5d7fe 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
@@ -17,34 +17,14 @@
 
 package org.apache.spark.sql.streaming
 
-import java.io.Serializable
+import java.time.Duration
 
-import org.apache.spark.annotation.{Evolving, Experimental}
-
-@Experimental
-@Evolving
 /**
- * Interface used for arbitrary stateful operations with the v2 API to capture
- * single value state.
+ * TTL Configuration for state variable. State values will not be returned 
past ttlDuration,
+ * and will be eventually removed from the state store. Any state update 
resets the ttl to
+ * current processing time plus ttlDuration.
+ *
+ * @param ttlDuration time to live duration for state
+ *                    stored in the state variable.
  */
-private[sql] trait ValueState[S] extends Serializable {
-
-  /** Whether state exists or not. */
-  def exists(): Boolean
-
-  /**
-   * Get the state value if it exists
-   * @throws java.util.NoSuchElementException if the state does not exist
-   */
-  @throws[NoSuchElementException]
-  def get(): S
-
-  /** Get the state if it exists as an option and None otherwise */
-  def getOption(): Option[S]
-
-  /** Update the value of the state. */
-  def update(newState: S): Unit
-
-  /** Remove this state. */
-  def clear(): Unit
-}
+case class TTLConfig(ttlDuration: Duration)
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
index 9c707c8308ab..8a2661e1a55b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
@@ -42,7 +42,11 @@ private[sql] trait ValueState[S] extends Serializable {
   /** Get the state if it exists as an option and None otherwise */
   def getOption(): Option[S]
 
-  /** Update the value of the state. */
+  /**
+   * Update the value of the state.
+   *
+   * @param newState    the new value
+   */
   def update(newState: S): Unit
 
   /** Remove this state. */
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index b2c443a8cce0..ff7c8fb3df4b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, 
StatefulProcessor, TimeoutMode}
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, 
StatefulProcessor, TimeoutMode, TTLMode}
 import org.apache.spark.sql.types._
 
 object CatalystSerde {
@@ -574,6 +574,7 @@ object TransformWithState {
       groupingAttributes: Seq[Attribute],
       dataAttributes: Seq[Attribute],
       statefulProcessor: StatefulProcessor[K, V, U],
+      ttlMode: TTLMode,
       timeoutMode: TimeoutMode,
       outputMode: OutputMode,
       child: LogicalPlan): LogicalPlan = {
@@ -584,6 +585,7 @@ object TransformWithState {
       groupingAttributes,
       dataAttributes,
       statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]],
+      ttlMode,
       timeoutMode,
       outputMode,
       keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
@@ -605,6 +607,7 @@ object TransformWithState {
       groupingAttributes: Seq[Attribute],
       dataAttributes: Seq[Attribute],
       statefulProcessor: StatefulProcessor[K, V, U],
+      ttlMode: TTLMode,
       timeoutMode: TimeoutMode,
       outputMode: OutputMode,
       child: LogicalPlan,
@@ -618,6 +621,7 @@ object TransformWithState {
       groupingAttributes,
       dataAttributes,
       statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]],
+      ttlMode,
       timeoutMode,
       outputMode,
       keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
@@ -639,6 +643,7 @@ case class TransformWithState(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     statefulProcessor: StatefulProcessor[Any, Any, Any],
+    ttlMode: TTLMode,
     timeoutMode: TimeoutMode,
     outputMode: OutputMode,
     keyEncoder: ExpressionEncoder[Any],
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 55ac3daa6209..f3713edd0ec0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.ReduceAggregator
 import org.apache.spark.sql.internal.TypedAggUtils
-import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, 
OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode}
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, 
OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode, 
TTLMode}
 
 /**
  * A [[Dataset]] has been logically grouped by a user specified grouping key.  
Users should not
@@ -652,16 +652,18 @@ class KeyValueGroupedDataset[K, V] private[sql](
    * invocations.
    *
    * @tparam U The type of the output objects. Must be encodable to Spark SQL 
types.
-   * @param statefulProcessor Instance of statefulProcessor whose functions 
will be invoked by the
-   *                          operator.
-   * @param timeoutMode The timeout mode of the stateful processor.
-   * @param outputMode The output mode of the stateful processor.
+   * @param statefulProcessor Instance of statefulProcessor whose functions 
will be invoked
+   *                          by the operator.
+   * @param timeoutMode       The timeout mode of the stateful processor.
+   * @param ttlMode           The ttlMode to evict user state on ttl expiration
+   * @param outputMode        The output mode of the stateful processor.
    *
    * See [[Encoder]] for more details on what types are encodable to Spark SQL.
    */
   private[sql] def transformWithState[U: Encoder](
       statefulProcessor: StatefulProcessor[K, V, U],
       timeoutMode: TimeoutMode,
+      ttlMode: TTLMode,
       outputMode: OutputMode): Dataset[U] = {
     Dataset[U](
       sparkSession,
@@ -669,6 +671,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
         groupingAttributes,
         dataAttributes,
         statefulProcessor,
+        ttlMode,
         timeoutMode,
         outputMode,
         child = logicalPlan
@@ -689,6 +692,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
    * @param statefulProcessor Instance of statefulProcessor whose functions 
will be invoked by the
    *                          operator.
    * @param timeoutMode The timeout mode of the stateful processor.
+   * @param ttlMode The ttlMode to evict user state on ttl expiration
    * @param outputMode The output mode of the stateful processor.
    * @param outputEncoder Encoder for the output type.
    *
@@ -697,9 +701,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
   private[sql] def transformWithState[U: Encoder](
       statefulProcessor: StatefulProcessor[K, V, U],
       timeoutMode: TimeoutMode,
+      ttlMode: TTLMode,
       outputMode: OutputMode,
       outputEncoder: Encoder[U]): Dataset[U] = {
-    transformWithState(statefulProcessor, timeoutMode, 
outputMode)(outputEncoder)
+    transformWithState(statefulProcessor, timeoutMode, ttlMode, 
outputMode)(outputEncoder)
   }
 
   /**
@@ -712,6 +717,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
    * @param statefulProcessor Instance of statefulProcessor whose functions 
will
    *                          be invoked by the operator.
    * @param timeoutMode       The timeout mode of the stateful processor.
+   * @param ttlMode           The ttlMode to evict user state on ttl expiration
    * @param outputMode        The output mode of the stateful processor.
    * @param initialState      User provided initial state that will be used to 
initiate state for
    *                          the query in the first batch.
@@ -721,6 +727,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   private[sql] def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
       timeoutMode: TimeoutMode,
+      ttlMode: TTLMode,
       outputMode: OutputMode,
       initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
     Dataset[U](
@@ -729,6 +736,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
         groupingAttributes,
         dataAttributes,
         statefulProcessor,
+        ttlMode,
         timeoutMode,
         outputMode,
         child = logicalPlan,
@@ -749,6 +757,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
    * @param statefulProcessor Instance of statefulProcessor whose functions 
will
    *                          be invoked by the operator.
    * @param timeoutMode       The timeout mode of the stateful processor.
+   * @param ttlMode           The ttlMode to evict user state on ttl expiration
    * @param outputMode        The output mode of the stateful processor.
    * @param initialState      User provided initial state that will be used to 
initiate state for
    *                          the query in the first batch.
@@ -760,11 +769,12 @@ class KeyValueGroupedDataset[K, V] private[sql](
   private[sql] def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
       timeoutMode: TimeoutMode,
+      ttlMode: TTLMode,
       outputMode: OutputMode,
       initialState: KeyValueGroupedDataset[K, S],
       outputEncoder: Encoder[U],
       initialStateEncoder: Encoder[S]): Dataset[U] = {
-    transformWithState(statefulProcessor, timeoutMode,
+    transformWithState(statefulProcessor, timeoutMode, ttlMode,
       outputMode, initialState)(outputEncoder, initialStateEncoder)
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index cc212d99f299..2c534eb36f9d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -751,7 +751,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case TransformWithState(
         keyDeserializer, valueDeserializer, groupingAttributes,
-        dataAttributes, statefulProcessor, timeoutMode, outputMode,
+        dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode,
         keyEncoder, outputAttr, child, hasInitialState,
         initialStateGroupingAttrs, initialStateDataAttrs,
         initialStateDeserializer, initialState) =>
@@ -761,6 +761,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
           groupingAttributes,
           dataAttributes,
           statefulProcessor,
+          ttlMode,
           timeoutMode,
           outputMode,
           keyEncoder,
@@ -925,12 +926,12 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
           hasInitialState, planLater(initialState), planLater(child)
         ) :: Nil
       case logical.TransformWithState(keyDeserializer, valueDeserializer, 
groupingAttributes,
-          dataAttributes, statefulProcessor, timeoutMode, outputMode, 
keyEncoder,
+          dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, 
keyEncoder,
           outputObjAttr, child, hasInitialState,
           initialStateGroupingAttrs, initialStateDataAttrs,
           initialStateDeserializer, initialState) =>
         
TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, 
valueDeserializer,
-          groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, 
outputMode,
+          groupingAttributes, dataAttributes, statefulProcessor, ttlMode, 
timeoutMode, outputMode,
           keyEncoder, outputObjAttr, planLater(child), hasInitialState,
           initialStateGroupingAttrs, initialStateDataAttrs,
           initialStateDeserializer, planLater(initialState)) :: Nil
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
index 662bef5716ea..56c9d2664d9e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import 
org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA}
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA}
 import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
StateStore, StateStoreErrors}
 import org.apache.spark.sql.streaming.ListState
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
index d2ccd0a77807..c58f32ed756d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
@@ -45,7 +45,7 @@ class MapStateImpl[K, V](
 
   /** Whether state exists or not. */
   override def exists(): Boolean = {
-    !store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).isEmpty
+    store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty
   }
 
   /** Get the state value if it exists */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
index 1d41db896cdf..b2dba7668d62 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
@@ -23,11 +23,15 @@ import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer
 import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
-import org.apache.spark.sql.types.{BinaryType, StructType}
+import org.apache.spark.sql.types.{BinaryType, LongType, StructType}
 
-object StateKeyValueRowSchema {
+object TransformWithStateKeyValueRowSchema {
   val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType)
-  val VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType)
+  val VALUE_ROW_SCHEMA: StructType = new StructType()
+    .add("value", BinaryType)
+  val VALUE_ROW_SCHEMA_WITH_TTL: StructType = new StructType()
+    .add("value", BinaryType)
+    .add("ttlExpirationMs", LongType)
 }
 
 /**
@@ -49,12 +53,17 @@ object StateKeyValueRowSchema {
 class StateTypesEncoder[GK, V](
     keySerializer: Serializer[GK],
     valEncoder: Encoder[V],
-    stateName: String) {
-  import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema._
+    stateName: String,
+    hasTtl: Boolean) {
+  import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._
 
   /** Variables reused for conversions between byte array and UnsafeRow */
   private val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA)
-  private val valueProjection = UnsafeProjection.create(VALUE_ROW_SCHEMA)
+  private val valueProjection = if (hasTtl) {
+    UnsafeProjection.create(VALUE_ROW_SCHEMA_WITH_TTL)
+  } else {
+    UnsafeProjection.create(VALUE_ROW_SCHEMA)
+  }
 
   /** Variables reused for value conversions between spark sql and object */
   private val valExpressionEnc = encoderFor(valEncoder)
@@ -65,22 +74,47 @@ class StateTypesEncoder[GK, V](
   // TODO: validate places that are trying to encode the key and check if we 
can eliminate/
   // add caching for some of these calls.
   def encodeGroupingKey(): UnsafeRow = {
+    val keyRow = keyProjection(InternalRow(serializeGroupingKey()))
+    keyRow
+  }
+
+  /**
+   * Encodes the provided grouping key into Spark UnsafeRow.
+   *
+   * @param groupingKeyBytes serialized grouping key byte array
+   * @return encoded UnsafeRow
+   */
+  def encodeSerializedGroupingKey(groupingKeyBytes: Array[Byte]): UnsafeRow = {
+    val keyRow = keyProjection(InternalRow(groupingKeyBytes))
+    keyRow
+  }
+
+  def serializeGroupingKey(): Array[Byte] = {
     val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
     if (keyOption.isEmpty) {
       throw StateStoreErrors.implicitKeyNotFound(stateName)
     }
-
     val groupingKey = keyOption.get.asInstanceOf[GK]
-    val keyByteArr = 
keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes()
-    val keyRow = keyProjection(InternalRow(keyByteArr))
-    keyRow
+    keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes()
   }
 
+  /**
+   * Encode the specified value in Spark UnsafeRow with no ttl.
+   */
   def encodeValue(value: V): UnsafeRow = {
     val objRow: InternalRow = objToRowSerializer.apply(value)
     val bytes = objRow.asInstanceOf[UnsafeRow].getBytes()
-    val valRow = valueProjection(InternalRow(bytes))
-    valRow
+    valueProjection(InternalRow(bytes))
+  }
+
+  /**
+   * Encode the specified value in Spark UnsafeRow
+   * with provided ttl expiration.
+   */
+  def encodeValue(value: V, expirationMs: Long): UnsafeRow = {
+    val objRow: InternalRow = objToRowSerializer.apply(value)
+    val bytes = objRow.asInstanceOf[UnsafeRow].getBytes()
+    valueProjection(InternalRow(bytes, expirationMs))
   }
 
   def decodeValue(row: UnsafeRow): V = {
@@ -89,14 +123,31 @@ class StateTypesEncoder[GK, V](
     val value = rowToObjDeserializer.apply(reusedValRow)
     value
   }
+
+  /**
+   * Decode the ttl information out of Value row. If the ttl has
+   * not been set (-1L specifies no user defined value), the API will
+   * return None.
+   */
+  def decodeTtlExpirationMs(row: UnsafeRow): Option[Long] = {
+    // ensure ttl has been set
+    assert(hasTtl)
+    val expirationMs = row.getLong(1)
+    if (expirationMs == -1) {
+      None
+    } else {
+      Some(expirationMs)
+    }
+  }
 }
 
 object StateTypesEncoder {
   def apply[GK, V](
       keySerializer: Serializer[GK],
       valEncoder: Encoder[V],
-      stateName: String): StateTypesEncoder[GK, V] = {
-    new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName)
+      stateName: String,
+      hasTtl: Boolean = false): StateTypesEncoder[GK, V] = {
+    new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName, hasTtl)
   }
 }
 
@@ -105,8 +156,9 @@ class CompositeKeyStateEncoder[GK, K, V](
     userKeyEnc: Encoder[K],
     valEncoder: Encoder[V],
     schemaForCompositeKeyRow: StructType,
-    stateName: String)
-  extends StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName) {
+    stateName: String,
+    hasTtl: Boolean = false)
+  extends StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName, 
hasTtl) {
 
   private val compositeKeyProjection = 
UnsafeProjection.create(schemaForCompositeKeyRow)
   private val reusedKeyRow = new UnsafeRow(userKeyEnc.schema.fields.length)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 5f3b794fd117..7bef62b7fcce 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.execution.streaming
 
+import java.util
 import java.util.UUID
 
 import org.apache.spark.TaskContext
@@ -24,7 +25,7 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.streaming.state._
-import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, 
StatefulProcessorHandle, TimeoutMode, ValueState}
+import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, 
StatefulProcessorHandle, TimeoutMode, TTLConfig, TTLMode, ValueState}
 import org.apache.spark.util.Utils
 
 /**
@@ -77,14 +78,22 @@ class StatefulProcessorHandleImpl(
     store: StateStore,
     runId: UUID,
     keyEncoder: ExpressionEncoder[Any],
+    ttlMode: TTLMode,
     timeoutMode: TimeoutMode,
-    isStreaming: Boolean = true)
+    isStreaming: Boolean = true,
+    batchTimestampMs: Option[Long] = None)
   extends StatefulProcessorHandle with Logging {
   import StatefulProcessorHandleState._
 
+  /**
+   * Stores all the active ttl states, and is used to cleanup expired values
+   * in [[doTtlCleanup()]] function.
+   */
+  private[sql] val ttlStates: util.List[TTLState] = new 
util.ArrayList[TTLState]()
+
   private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000"
-  private def buildQueryInfo(): QueryInfo = {
 
+  private def buildQueryInfo(): QueryInfo = {
     val taskCtxOpt = Option(TaskContext.get())
     val (queryId, batchId) = if (!isStreaming) {
       (BATCH_QUERY_ID, 0L)
@@ -103,22 +112,33 @@ class StatefulProcessorHandleImpl(
 
   private var currState: StatefulProcessorHandleState = CREATED
 
-  private def verify(condition: => Boolean, msg: String): Unit = {
-    if (!condition) {
-      throw new IllegalStateException(msg)
-    }
-  }
-
   def setHandleState(newState: StatefulProcessorHandleState): Unit = {
     currState = newState
   }
 
   def getHandleState: StatefulProcessorHandleState = currState
 
-  override def getValueState[T](stateName: String, valEncoder: Encoder[T]): 
ValueState[T] = {
+  override def getValueState[T](
+      stateName: String,
+      valEncoder: Encoder[T]): ValueState[T] = {
     verifyStateVarOperations("get_value_state")
-    val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, 
valEncoder)
-    resultState
+
+    new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder)
+  }
+
+  override def getValueState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ValueState[T] = {
+    verifyStateVarOperations("get_value_state")
+    validateTTLConfig(ttlConfig, stateName)
+
+    assert(batchTimestampMs.isDefined)
+    val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
+      keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
+    ttlStates.add(valueStateWithTTL)
+
+    valueStateWithTTL
   }
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
@@ -185,6 +205,16 @@ class StatefulProcessorHandleImpl(
     timerState.listTimers()
   }
 
+  /**
+   * Performs the user state cleanup based on assigned TTL values. Any state
+   * which is expired will be cleaned up from StateStore.
+   */
+  def doTtlCleanup(): Unit = {
+    ttlStates.forEach { s =>
+      s.clearExpiredState()
+    }
+  }
+
   /**
    * Function to delete and purge state variable if defined previously
    *
@@ -209,4 +239,13 @@ class StatefulProcessorHandleImpl(
     val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, 
userKeyEnc, valEncoder)
     resultState
   }
+
+  private def validateTTLConfig(ttlConfig: TTLConfig, stateName: String): Unit 
= {
+    val ttlDuration = ttlConfig.ttlDuration
+    if (ttlMode != TTLMode.ProcessingTimeTTL()) {
+      throw StateStoreErrors.cannotProvideTTLConfigForNoTTLMode(stateName)
+    } else if (ttlDuration == null || ttlDuration.isNegative || 
ttlDuration.isZero) {
+      throw StateStoreErrors.ttlMustBePositive("update", stateName)
+    }
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
new file mode 100644
index 000000000000..0ae93549b731
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.time.Duration
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import 
org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, 
StateStore}
+import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, 
StructField, StructType}
+
+object StateTTLSchema {
+  val TTL_KEY_ROW_SCHEMA: StructType = new StructType()
+    .add("expirationMs", LongType)
+    .add("groupingKey", BinaryType)
+  val TTL_VALUE_ROW_SCHEMA: StructType =
+    StructType(Array(StructField("__dummy__", NullType)))
+}
+
+/**
+ * Encapsulates the ttl row information stored in [[SingleKeyTTLStateImpl]].
+ *
+ * @param groupingKey grouping key for which ttl is set
+ * @param expirationMs expiration time for the grouping key
+ */
+case class SingleKeyTTLRow(
+    groupingKey: Array[Byte],
+    expirationMs: Long)
+
+/**
+ * Represents the underlying state for secondary TTL Index for a user defined
+ * state variable.
+ *
+ * This state allows Spark to query ttl values based on expiration time
+ * allowing efficient ttl cleanup.
+ */
+trait TTLState {
+
+  /**
+   * Perform the user state clean up based on ttl values stored in
+   * this state. Note that it is not safe to call this operation concurrently
+   * when the user can also modify the underlying State. Cleanup should be 
initiated
+   * after arbitrary state operations are completed by the user.
+   */
+  def clearExpiredState(): Unit
+
+  /**
+   * Clears the user state associated with this grouping key
+   * if it has expired. This function is called by Spark to perform
+   * cleanup at the end of transformWithState processing.
+   *
+   * Spark uses a secondary index to determine if the user state for
+   * this grouping key has expired. However, it is possible that the user
+   * has updated the TTL and secondary index is out of date. Implementations
+   * must validate that the user State has actually expired before cleanup 
based
+   * on their own State data.
+   *
+   * @param groupingKey grouping key for which cleanup should be performed.
+   */
+  def clearIfExpired(groupingKey: Array[Byte]): Unit
+}
+
+/**
+ * Manages the ttl information for user state keyed with a single key 
(grouping key).
+ */
+abstract class SingleKeyTTLStateImpl(
+    stateName: String,
+    store: StateStore,
+    ttlExpirationMs: Long)
+  extends TTLState {
+
+  import org.apache.spark.sql.execution.streaming.StateTTLSchema._
+
+  private val ttlColumnFamilyName = s"_ttl_$stateName"
+  private val ttlKeyEncoder = UnsafeProjection.create(TTL_KEY_ROW_SCHEMA)
+
+  // empty row used for values
+  private val EMPTY_ROW =
+    
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+
+  store.createColFamilyIfAbsent(ttlColumnFamilyName, TTL_KEY_ROW_SCHEMA, 
TTL_VALUE_ROW_SCHEMA,
+    RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, 1), isInternal = true)
+
+  def upsertTTLForStateKey(
+      expirationMs: Long,
+      groupingKey: Array[Byte]): Unit = {
+    val encodedTtlKey = ttlKeyEncoder(InternalRow(expirationMs, groupingKey))
+    store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName)
+  }
+
+  /**
+   * Clears any state which has ttl older than [[ttlExpirationMs]].
+   */
+  override def clearExpiredState(): Unit = {
+    val iterator = store.iterator(ttlColumnFamilyName)
+
+    iterator.takeWhile { kv =>
+      val expirationMs = kv.key.getLong(0)
+      StateTTL.isExpired(expirationMs, ttlExpirationMs)
+    }.foreach { kv =>
+      val groupingKey = kv.key.getBinary(1)
+      clearIfExpired(groupingKey)
+      store.remove(kv.key, ttlColumnFamilyName)
+    }
+  }
+
+  private[sql] def ttlIndexIterator(): Iterator[SingleKeyTTLRow] = {
+    val ttlIterator = store.iterator(ttlColumnFamilyName)
+
+    new Iterator[SingleKeyTTLRow] {
+      override def hasNext: Boolean = ttlIterator.hasNext
+
+      override def next(): SingleKeyTTLRow = {
+        val kv = ttlIterator.next()
+        SingleKeyTTLRow(
+          expirationMs = kv.key.getLong(0),
+          groupingKey = kv.key.getBinary(1)
+        )
+      }
+    }
+  }
+}
+
+/**
+ * Helper methods for user State TTL.
+ */
+object StateTTL {
+  def calculateExpirationTimeForDuration(
+      ttlDuration: Duration,
+      batchTtlExpirationMs: Long): Long = {
+    batchTtlExpirationMs + ttlDuration.toMillis
+  }
+
+  def isExpired(
+      expirationMs: Long,
+      batchTtlExpirationMs: Long): Boolean = {
+    batchTtlExpirationMs >= expirationMs
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
index af321eecb4db..8d410b677c84 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
@@ -78,25 +78,25 @@ class TimerStateImpl(
 
   private val secIndexKeyEncoder = 
UnsafeProjection.create(keySchemaForSecIndex)
 
-  val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) {
+  private val timerCFName = if (timeoutMode == TimeoutMode.ProcessingTime) {
     TimerStateUtils.PROC_TIMERS_STATE_NAME
   } else {
     TimerStateUtils.EVENT_TIMERS_STATE_NAME
   }
 
-  val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF
+  private val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF
   store.createColFamilyIfAbsent(keyToTsCFName, schemaForKeyRow,
     schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1),
     useMultipleValuesPerKey = false, isInternal = true)
 
-  val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
+  private val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
   store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex,
     schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1),
     useMultipleValuesPerKey = false, isInternal = true)
 
   private def getGroupingKey(cfName: String): Any = {
     val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
-    if (!keyOption.isDefined) {
+    if (keyOption.isEmpty) {
       throw StateStoreErrors.implicitKeyNotFound(cfName)
     }
     keyOption.get
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 66c19fa22304..eaf51614d7cb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -28,10 +28,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution._
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, 
StatefulProcessorWithInitialState, TimeoutMode}
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.streaming._
 import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, 
Utils}
 
 /**
@@ -42,6 +42,7 @@ import org.apache.spark.util.{CompletionIterator, 
SerializableConfiguration, Uti
  * @param groupingAttributes used to group the data
  * @param dataAttributes used to read the data
  * @param statefulProcessor processor methods called on underlying data
+ * @param ttlMode defines the TTL mode for user state
  * @param timeoutMode defines the timeout mode
  * @param outputMode defines the output mode for the statefulProcessor
  * @param keyEncoder expression encoder for the key type
@@ -58,6 +59,7 @@ case class TransformWithStateExec(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     statefulProcessor: StatefulProcessor[Any, Any, Any],
+    ttlMode: TTLMode,
     timeoutMode: TimeoutMode,
     outputMode: OutputMode,
     keyEncoder: ExpressionEncoder[Any],
@@ -78,17 +80,14 @@ case class TransformWithStateExec(
   override def shortName: String = "transformWithStateExec"
 
   override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
-    timeoutMode match {
+    if (ttlMode == TTLMode.ProcessingTimeTTL() || timeoutMode == 
TimeoutMode.ProcessingTime()) {
       // TODO: check if we can return true only if actual timers are registered
-      case ProcessingTime =>
-        true
-
-      case EventTime =>
-        eventTimeWatermarkForEviction.isDefined &&
-          newInputWatermark > eventTimeWatermarkForEviction.get
-
-      case _ =>
-        false
+      true
+    } else if (timeoutMode == TimeoutMode.EventTime()) {
+      eventTimeWatermarkForEviction.isDefined &&
+        newInputWatermark > eventTimeWatermarkForEviction.get
+    } else {
+      false
     }
   }
 
@@ -102,10 +101,6 @@ case class TransformWithStateExec(
 
   override def keyExpressions: Seq[Attribute] = groupingAttributes
 
-  protected val schemaForKeyRow: StructType = new StructType().add("key", 
BinaryType)
-
-  protected val schemaForValueRow: StructType = new StructType().add("value", 
BinaryType)
-
   /**
    * Distribute by grouping attributes - We need the underlying data and the 
initial state data
    * to have the same grouping so that the data are co-located on the same 
task.
@@ -284,6 +279,8 @@ case class TransformWithStateExec(
       allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - 
updatesStartTimeNs)
       commitTimeMs += timeTakenMs {
         if (isStreaming) {
+          // clean up any expired user state
+          processorHandle.doTtlCleanup()
           store.commit()
         } else {
           store.abort()
@@ -300,19 +297,8 @@ case class TransformWithStateExec(
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
-    timeoutMode match {
-      case ProcessingTime =>
-        if (batchTimestampMs.isEmpty) {
-          StateStoreErrors.missingTimeoutValues(timeoutMode.toString)
-        }
-
-      case EventTime =>
-        if (eventTimeWatermarkForEviction.isEmpty) {
-          StateStoreErrors.missingTimeoutValues(timeoutMode.toString)
-        }
-
-      case _ =>
-    }
+    validateTTLMode()
+    validateTimeoutMode()
 
     if (hasInitialState) {
       val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
@@ -332,9 +318,9 @@ case class TransformWithStateExec(
             val storeProviderId = StateStoreProviderId(stateStoreId, 
stateInfo.get.queryRunId)
             val store = StateStore.get(
               storeProviderId = storeProviderId,
-              keySchema = schemaForKeyRow,
-              valueSchema = schemaForValueRow,
-              NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+              KEY_ROW_SCHEMA,
+              VALUE_ROW_SCHEMA,
+              NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
               version = stateInfo.get.storeVersion,
               useColumnFamilies = true,
               storeConf = storeConf,
@@ -352,9 +338,9 @@ case class TransformWithStateExec(
       if (isStreaming) {
         child.execute().mapPartitionsWithStateStore[InternalRow](
           getStateInfo,
-          schemaForKeyRow,
-          schemaForValueRow,
-          NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+          KEY_ROW_SCHEMA,
+          VALUE_ROW_SCHEMA,
+          NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
           session.sqlContext.sessionState,
           Some(session.sqlContext.streams.stateStoreCoordinator),
           useColumnFamilies = true
@@ -402,9 +388,9 @@ case class TransformWithStateExec(
     // Create StateStoreProvider for this partition
     val stateStoreProvider = StateStoreProvider.createAndInit(
       providerId,
-      schemaForKeyRow,
-      schemaForValueRow,
-      NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+      KEY_ROW_SCHEMA,
+      VALUE_ROW_SCHEMA,
+      NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
       useColumnFamilies = true,
       storeConf = storeConf,
       hadoopConf = hadoopConfBroadcast.value.value,
@@ -427,10 +413,11 @@ case class TransformWithStateExec(
   private def processData(store: StateStore, singleIterator: 
Iterator[InternalRow]):
     CompletionIterator[InternalRow, Iterator[InternalRow]] = {
     val processorHandle = new StatefulProcessorHandleImpl(
-      store, getStateInfo.queryRunId, keyEncoder, timeoutMode, isStreaming)
+      store, getStateInfo.queryRunId, keyEncoder, ttlMode, timeoutMode,
+      isStreaming, batchTimestampMs)
     assert(processorHandle.getHandleState == 
StatefulProcessorHandleState.CREATED)
     statefulProcessor.setHandle(processorHandle)
-    statefulProcessor.init(outputMode, timeoutMode)
+    statefulProcessor.init(outputMode, timeoutMode, ttlMode)
     processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
     processDataWithPartition(singleIterator, store, processorHandle)
   }
@@ -441,10 +428,10 @@ case class TransformWithStateExec(
       initStateIterator: Iterator[InternalRow]):
     CompletionIterator[InternalRow, Iterator[InternalRow]] = {
     val processorHandle = new StatefulProcessorHandleImpl(store, 
getStateInfo.queryRunId,
-      keyEncoder, timeoutMode, isStreaming)
+      keyEncoder, ttlMode, timeoutMode, isStreaming)
     assert(processorHandle.getHandleState == 
StatefulProcessorHandleState.CREATED)
     statefulProcessor.setHandle(processorHandle)
-    statefulProcessor.init(outputMode, timeoutMode)
+    statefulProcessor.init(outputMode, timeoutMode, ttlMode)
     processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
 
     // Check if is first batch
@@ -462,9 +449,36 @@ case class TransformWithStateExec(
 
     processDataWithPartition(childDataIterator, store, processorHandle)
   }
+
+  private def validateTimeoutMode(): Unit = {
+    timeoutMode match {
+      case ProcessingTime =>
+        if (batchTimestampMs.isEmpty) {
+          throw StateStoreErrors.missingTimeoutValues(timeoutMode.toString)
+        }
+
+      case EventTime =>
+        if (eventTimeWatermarkForEviction.isEmpty) {
+          throw StateStoreErrors.missingTimeoutValues(timeoutMode.toString)
+        }
+
+      case _ =>
+    }
+  }
+
+  private def validateTTLMode(): Unit = {
+    ttlMode match {
+      case ProcessingTimeTTL =>
+        if (batchTimestampMs.isEmpty) {
+          throw StateStoreErrors.missingTTLValues(ttlMode.toString)
+        }
+
+      case _ =>
+    }
+  }
 }
 
-// scalastyle:off
+// scalastyle:off argcount
 object TransformWithStateExec {
 
   // Plan logical transformWithState for batch queries
@@ -474,6 +488,7 @@ object TransformWithStateExec {
       groupingAttributes: Seq[Attribute],
       dataAttributes: Seq[Attribute],
       statefulProcessor: StatefulProcessor[Any, Any, Any],
+      ttlMode: TTLMode,
       timeoutMode: TimeoutMode,
       outputMode: OutputMode,
       keyEncoder: ExpressionEncoder[Any],
@@ -499,6 +514,7 @@ object TransformWithStateExec {
       groupingAttributes,
       dataAttributes,
       statefulProcessor,
+      ttlMode,
       timeoutMode,
       outputMode,
       keyEncoder,
@@ -516,4 +532,5 @@ object TransformWithStateExec {
       initialState)
   }
 }
-// scalastyle:on
+// scalastyle:on argcount
+
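
At the API surface, this change shows up as one extra TTLMode argument on
transformWithState. A hedged usage sketch, mirroring the Scala test updates
later in this diff (processor and input names are taken from those tests):

```
// Sketch only: argument order follows the test changes in this commit.
val result = inputData.toDS()
  .groupByKey(x => x)
  .transformWithState(new RunningCountStatefulProcessor(),
    TimeoutMode.NoTimeouts(),
    TTLMode.NoTTL(),
    OutputMode.Update())
```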
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
index 08876ca3032e..d916011245c0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.execution.streaming
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import 
org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA}
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA}
 import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
StateStore}
 import org.apache.spark.sql.streaming.ValueState
 
@@ -29,7 +28,7 @@ import org.apache.spark.sql.streaming.ValueState
  * variables used in the streaming transformWithState operator.
  * @param store - reference to the StateStore instance to be used for storing 
state
  * @param stateName - name of logical state partition
- * @param keyEnc - Spark SQL encoder for key
+ * @param keyExprEnc - Spark SQL encoder for key
  * @param valEncoder - Spark SQL encoder for value
  * @tparam S - data type of object that will be stored
  */
@@ -37,18 +36,22 @@ class ValueStateImpl[S](
     store: StateStore,
     stateName: String,
     keyExprEnc: ExpressionEncoder[Any],
-    valEncoder: Encoder[S]) extends ValueState[S] with Logging {
+    valEncoder: Encoder[S])
+  extends ValueState[S] with Logging {
 
   private val keySerializer = keyExprEnc.createSerializer()
-
   private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, 
stateName)
 
-  store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
-    NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
+  initialize()
+
+  private def initialize(): Unit = {
+    store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
+      NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
+  }
 
   /** Function to check if state exists. Returns true if present and false 
otherwise */
   override def exists(): Boolean = {
-    getImpl() != null
+    get() != null
   }
 
   /** Function to return Option of value if exists and None otherwise */
@@ -58,7 +61,9 @@ class ValueStateImpl[S](
 
   /** Function to return associated value with key if exists and null 
otherwise */
   override def get(): S = {
-    val retRow = getImpl()
+    val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+    val retRow = store.get(encodedGroupingKey, stateName)
+
     if (retRow != null) {
       stateTypesEncoder.decodeValue(retRow)
     } else {
@@ -66,14 +71,12 @@ class ValueStateImpl[S](
     }
   }
 
-  private def getImpl(): UnsafeRow = {
-    store.get(stateTypesEncoder.encodeGroupingKey(), stateName)
-  }
-
   /** Function to update and overwrite state associated with given key */
   override def update(newState: S): Unit = {
-    store.put(stateTypesEncoder.encodeGroupingKey(),
-      stateTypesEncoder.encodeValue(newState), stateName)
+    val encodedValue = stateTypesEncoder.encodeValue(newState)
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    
store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey),
+      encodedValue, stateName)
   }
 
   /** Function to remove state for given key */
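
Splitting update() into a serialize step plus an encode step may look
redundant here, but it lets the TTL variant in the next file reuse the
serialized key bytes for its secondary index. A sketch of the pattern,
using only the encoder calls visible in this diff:

```
// serializeGroupingKey yields the raw key bytes once; the same bytes are
// wrapped into the store's key row and, in the TTL variant, also passed
// to upsertTTLForStateKey without re-serializing.
val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
store.put(
  stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey),
  stateTypesEncoder.encodeValue(newState),
  stateName)
```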
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
new file mode 100644
index 000000000000..d3c9eb9de204
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
+import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
StateStore}
+import org.apache.spark.sql.streaming.{TTLConfig, ValueState}
+
+/**
+ * Class that provides a concrete implementation for a single value state 
associated with state
+ * variables (with ttl expiration support) used in the streaming 
transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing 
state
+ * @param stateName - name of logical state partition
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param valEncoder - Spark SQL encoder for value
+ * @param ttlConfig - TTL configuration for values stored in this state
+ * @param batchTimestampMs - current batch processing timestamp
+ * @tparam S - data type of object that will be stored
+ */
+class ValueStateImplWithTTL[S](
+    store: StateStore,
+    stateName: String,
+    keyExprEnc: ExpressionEncoder[Any],
+    valEncoder: Encoder[S],
+    ttlConfig: TTLConfig,
+    batchTimestampMs: Long)
+  extends SingleKeyTTLStateImpl(stateName, store, batchTimestampMs) with 
ValueState[S] {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder,
+    stateName, hasTtl = true)
+  private val ttlExpirationMs =
+    StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, 
batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+    store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, 
VALUE_ROW_SCHEMA_WITH_TTL,
+      NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
+  }
+
+  /** Function to check if state exists. Returns true if present and false 
otherwise */
+  override def exists(): Boolean = {
+    get() != null
+  }
+
+  /** Function to return Option of value if exists and None otherwise */
+  override def getOption(): Option[S] = {
+    Option(get())
+  }
+
+  /** Function to return associated value with key if exists and null 
otherwise */
+  override def get(): S = {
+    val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+    val retRow = store.get(encodedGroupingKey, stateName)
+
+    if (retRow != null) {
+      val resState = stateTypesEncoder.decodeValue(retRow)
+
+      if (!isExpired(retRow)) {
+        resState
+      } else {
+        null.asInstanceOf[S]
+      }
+    } else {
+      null.asInstanceOf[S]
+    }
+  }
+
+  /** Function to update and overwrite state associated with given key */
+  override def update(newState: S): Unit = {
+    val encodedValue = stateTypesEncoder.encodeValue(newState, ttlExpirationMs)
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    
store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey),
+      encodedValue, stateName)
+    upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey)
+  }
+
+  /** Function to remove state for given key */
+  override def clear(): Unit = {
+    store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
+  }
+
+  def clearIfExpired(groupingKey: Array[Byte]): Unit = {
+    val encodedGroupingKey = 
stateTypesEncoder.encodeSerializedGroupingKey(groupingKey)
+    val retRow = store.get(encodedGroupingKey, stateName)
+
+    if (retRow != null && isExpired(retRow)) {
+      store.remove(encodedGroupingKey, stateName)
+    }
+  }
+
+  private def isExpired(valueRow: UnsafeRow): Boolean = {
+    val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow)
+    expirationMs.exists(StateTTL.isExpired(_, batchTimestampMs))
+  }
+
+  /*
+   * Internal methods to probe state for testing. The methods below exist for
+   * unit tests to read the state ttl values, and to ensure that values are
+   * persisted correctly in the underlying state store.
+   */
+
+  /**
+   * Retrieves the value from state even if it has expired. This method is used
+   * in tests to read the state store value, and to verify that it is cleaned
+   * up at the end of the micro-batch.
+   */
+  private[sql] def getWithoutEnforcingTTL(): Option[S] = {
+    val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+    val retRow = store.get(encodedGroupingKey, stateName)
+
+    if (retRow != null) {
+      val resState = stateTypesEncoder.decodeValue(retRow)
+      Some(resState)
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Read the ttl value associated with the grouping key.
+   */
+  private[sql] def getTTLValue(): Option[Long] = {
+    val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+    val retRow = store.get(encodedGroupingKey, stateName)
+
+    if (retRow != null) {
+      stateTypesEncoder.decodeTtlExpirationMs(retRow)
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Get all ttl values stored in ttl state for the current
+   * implicit grouping key.
+   */
+  private[sql] def getValuesInTTLState(): Iterator[Long] = {
+    val ttlIterator = ttlIndexIterator()
+    val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    var nextValue: Option[Long] = None
+
+    new Iterator[Long] {
+      override def hasNext: Boolean = {
+        while (nextValue.isEmpty && ttlIterator.hasNext) {
+          val nextTtlValue = ttlIterator.next()
+          val groupingKey = nextTtlValue.groupingKey
+          if (groupingKey sameElements implicitGroupingKey) {
+            nextValue = Some(nextTtlValue.expirationMs)
+          }
+        }
+        nextValue.isDefined
+      }
+
+      override def next(): Long = {
+        val result = nextValue.get
+        nextValue = None
+        result
+      }
+    }
+  }
+}
+
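
A hedged sketch of how a stateful processor obtains this TTL-enabled state;
per the suite changes below, getValueState with a TTLConfig returns a
ValueStateImplWithTTL under the hood (state name and duration illustrative):

```
import java.time.Duration

// Inside StatefulProcessor.init().
val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
val countState = getHandle.getValueState[Long](
  "countState", Encoders.scalaLong, ttlConfig)
```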
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index 2f72cbb0b0fc..6c63aa94e75b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -39,6 +39,13 @@ object StateStoreErrors {
     )
   }
 
+  def missingTTLValues(ttlMode: String): SparkException = {
+    SparkException.internalError(
+      msg = s"Failed to find timeout values for ttlMode=$ttlMode",
+      category = "TWS"
+    )
+  }
+
   def unsupportedOperationOnMissingColumnFamily(operationName: String, 
colFamilyName: String):
     StateStoreUnsupportedOperationOnMissingColumnFamily = {
     new StateStoreUnsupportedOperationOnMissingColumnFamily(operationName, 
colFamilyName)
@@ -117,6 +124,16 @@ object StateStoreErrors {
     StatefulProcessorCannotReInitializeState = {
     new StatefulProcessorCannotReInitializeState(groupingKey)
   }
+
+  def cannotProvideTTLConfigForNoTTLMode(stateName: String):
+    StatefulProcessorCannotAssignTTLInNoTTLMode = {
+    new StatefulProcessorCannotAssignTTLInNoTTLMode(stateName)
+  }
+
+  def ttlMustBePositive(
+      operationType: String,
+      stateName: String): StatefulProcessorTTLMustBePositive = {
+    new StatefulProcessorTTLMustBePositive(operationType, stateName)
+  }
 }
 
 class 
StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: 
String)
@@ -192,3 +209,15 @@ class 
StateStoreNullTypeOrderingColsNotSupported(fieldName: String, index: Strin
   extends SparkUnsupportedOperationException(
     errorClass = "STATE_STORE_NULL_TYPE_ORDERING_COLS_NOT_SUPPORTED",
     messageParameters = Map("fieldName" -> fieldName, "index" -> index))
+
+class StatefulProcessorCannotAssignTTLInNoTTLMode(stateName: String)
+  extends SparkUnsupportedOperationException(
+    errorClass = "STATEFUL_PROCESSOR_CANNOT_ASSIGN_TTL_IN_NO_TTL_MODE",
+    messageParameters = Map("stateName" -> stateName))
+
+class StatefulProcessorTTLMustBePositive(
+    operationType: String,
+    stateName: String)
+  extends SparkUnsupportedOperationException(
+    errorClass = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE",
+    messageParameters = Map("operationType" -> operationType, "stateName" -> 
stateName))
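
The call site that raises the new errors is not part of this hunk; a hedged
sketch of the guard implied by the new test below (operationType "update",
with null, zero, and negative durations all rejected). The helper name is
an assumption:

```
// Assumption: a helper along these lines guards TTL state creation; only
// the two error constructors above are confirmed by this diff.
def validateTtlConfig(ttlDuration: java.time.Duration, stateName: String): Unit = {
  if (ttlDuration == null || ttlDuration.isNegative || ttlDuration.isZero) {
    throw StateStoreErrors.ttlMustBePositive("update", stateName)
  }
}
```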
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 02927f1d962f..f9f075f4468d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -207,6 +207,7 @@ public class JavaDatasetSuite implements Serializable {
     Dataset<String> transformWithStateMapped = grouped.transformWithState(
       new TestStatefulProcessorWithInitialState(),
       TimeoutMode.NoTimeouts(),
+      TTLMode.NoTTL(),
       OutputMode.Append(),
       kvInitStateMappedDS,
       Encoders.STRING(),
@@ -362,6 +363,7 @@ public class JavaDatasetSuite implements Serializable {
     Dataset<String> transformWithStateMapped = grouped.transformWithState(
       testStatefulProcessor,
       TimeoutMode.NoTimeouts(),
+      TTLMode.NoTTL(),
       OutputMode.Append(),
       Encoders.STRING());
 
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
index 3122e0e337a3..c6d705af5f2d 100644
--- 
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
+++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
@@ -36,7 +36,10 @@ public class TestStatefulProcessor extends 
StatefulProcessor<Integer, String, St
   private transient ListState<String> keysList;
 
   @Override
-  public void init(OutputMode outputMode, TimeoutMode timeoutMode) {
+  public void init(
+      OutputMode outputMode,
+      TimeoutMode timeoutMode,
+      TTLMode ttlMode) {
     countState = this.getHandle().getValueState("countState",
       Encoders.LONG());
 
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
 
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
index 247bae3a3f3c..db0b222145c4 100644
--- 
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
+++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
@@ -35,7 +35,10 @@ public class TestStatefulProcessorWithInitialState
   private transient ValueState<String> testState;
 
   @Override
-  public void init(OutputMode outputMode, TimeoutMode timeoutMode) {
+  public void init(
+      OutputMode outputMode,
+      TimeoutMode timeoutMode,
+      TTLMode ttlMode) {
     testState = this.getHandle().getValueState("testState",
       Encoders.STRING());
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
index e895e475b74d..51cfc1548b39 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl}
-import org.apache.spark.sql.streaming.{ListState, TimeoutMode, ValueState}
+import org.apache.spark.sql.streaming.{ListState, TimeoutMode, TTLMode, 
ValueState}
 
 /**
  * Class that adds unit tests for ListState types used in arbitrary stateful
@@ -37,7 +37,8 @@ class ListStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val listState: ListState[Long] = handle.getListState[Long]("listState", 
Encoders.scalaLong)
 
@@ -47,7 +48,7 @@ class ListStateSuite extends StateVariableSuiteBase {
       }
 
       checkError(
-        exception = e.asInstanceOf[SparkIllegalArgumentException],
+        exception = e,
         errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE",
         sqlState = Some("42601"),
         parameters = Map("stateName" -> "listState")
@@ -70,7 +71,8 @@ class ListStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState: ListState[Long] = handle.getListState[Long]("testState", 
Encoders.scalaLong)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
@@ -98,7 +100,8 @@ class ListStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState1: ListState[Long] = 
handle.getListState[Long]("testState1", Encoders.scalaLong)
       val testState2: ListState[Long] = 
handle.getListState[Long]("testState2", Encoders.scalaLong)
@@ -136,7 +139,8 @@ class ListStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val listState1: ListState[Long] = 
handle.getListState[Long]("listState1", Encoders.scalaLong)
       val listState2: ListState[Long] = 
handle.getListState[Long]("listState2", Encoders.scalaLong)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
index ce72061d39ea..7fa41b12795e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
@@ -22,7 +22,7 @@ import java.util.UUID
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl}
-import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, 
ValueState}
+import org.apache.spark.sql.streaming.{ListState, MapState, TimeoutMode, 
TTLMode, ValueState}
 import org.apache.spark.sql.types.{BinaryType, StructType}
 
 /**
@@ -39,7 +39,8 @@ class MapStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState: MapState[String, Double] =
         handle.getMapState[String, Double]("testState", Encoders.STRING, 
Encoders.scalaDouble)
@@ -73,7 +74,8 @@ class MapStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState1: MapState[Long, Double] =
         handle.getMapState[Long, Double]("testState1", Encoders.scalaLong, 
Encoders.scalaDouble)
@@ -112,7 +114,8 @@ class MapStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val mapTestState1: MapState[String, Int] =
         handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, 
Encoders.scalaInt)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
index 662a5dbfaac4..a32b4111eae8 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.time.Duration
 import java.util.UUID
 
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, StatefulProcessorHandleState}
-import org.apache.spark.sql.streaming.TimeoutMode
+import org.apache.spark.sql.streaming.{TimeoutMode, TTLConfig, TTLMode}
+
 
 /**
  * Class that adds tests to verify operations based on stateful processor 
handle
@@ -48,7 +50,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
+          UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), 
getTimeoutMode(timeoutMode))
         assert(handle.getHandleState === StatefulProcessorHandleState.CREATED)
         handle.getValueState[Long]("testState", Encoders.scalaLong)
       }
@@ -89,7 +91,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
+          UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), 
getTimeoutMode(timeoutMode))
 
         Seq(StatefulProcessorHandleState.INITIALIZED,
           StatefulProcessorHandleState.DATA_PROCESSED,
@@ -107,7 +109,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
-        UUID.randomUUID(), keyExprEncoder, TimeoutMode.NoTimeouts())
+        UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), 
TimeoutMode.NoTimeouts())
       val ex = intercept[SparkUnsupportedOperationException] {
         handle.registerTimer(10000L)
       }
@@ -143,7 +145,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
+          UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), 
getTimeoutMode(timeoutMode))
         handle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
         assert(handle.getHandleState === 
StatefulProcessorHandleState.INITIALIZED)
 
@@ -164,7 +166,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
+          UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), 
getTimeoutMode(timeoutMode))
         handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED)
         assert(handle.getHandleState === 
StatefulProcessorHandleState.DATA_PROCESSED)
 
@@ -204,7 +206,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
       tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
-          UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
+          UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), 
getTimeoutMode(timeoutMode))
 
         Seq(StatefulProcessorHandleState.CREATED,
           StatefulProcessorHandleState.TIMER_PROCESSED,
@@ -216,4 +218,34 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
       }
     }
   }
+
+  test(s"ttl States are populated for ttlMode=ProcessingTime") {
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store,
+        UUID.randomUUID(), keyExprEncoder, TTLMode.ProcessingTimeTTL(), 
TimeoutMode.NoTimeouts(),
+        batchTimestampMs = Some(10))
+
+      val valueStateWithTTL = handle.getValueState("testState",
+        Encoders.STRING, TTLConfig(Duration.ofHours(1)))
+
+      // create another state without TTL; this should not be captured in the
+      // handle's ttlStates
+      handle.getValueState("testStateWithoutTTL", Encoders.STRING)
+
+      assert(handle.ttlStates.size() === 1)
+      assert(handle.ttlStates.get(0) === valueStateWithTTL)
+    }
+  }
+
+  test(s"ttl States are not populated for ttlMode=NoTTL") {
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store,
+        UUID.randomUUID(), keyExprEncoder, TTLMode.NoTTL(), 
TimeoutMode.NoTimeouts())
+
+      handle.getValueState("testState", Encoders.STRING)
+
+      assert(handle.ttlStates.isEmpty)
+    }
+  }
 }
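
doTtlCleanup() (invoked before commit in TransformWithStateExec above)
presumably walks the handle's registered ttlStates; a hedged sketch of that
loop, with the per-state cleanup method name assumed:

```
// Assumption: cleanup delegates to each registered TTL state; only
// ttlStates and the doTtlCleanup() call are confirmed by this diff.
def doTtlCleanup(): Unit = {
  ttlStates.forEach { ttlState =>
    ttlState.clearExpiredState()
  }
}
```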
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 8668b58672c7..102164d9c15f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.time.Duration
 import java.util.UUID
 
 import scala.util.Random
@@ -27,9 +28,9 @@ import org.scalatest.BeforeAndAfter
 import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl}
+import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, ValueStateImplWithTTL}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{TimeoutMode, ValueState}
+import org.apache.spark.sql.streaming.{TimeoutMode, TTLConfig, TTLMode, 
ValueState}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 
@@ -48,7 +49,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val stateName = "testState"
       val testState: ValueState[Long] = 
handle.getValueState[Long]("testState", Encoders.scalaLong)
@@ -78,7 +80,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
         testState.update(123)
       }
       checkError(
-        ex.asInstanceOf[SparkException],
+        ex1.asInstanceOf[SparkException],
         errorClass = "INTERNAL_ERROR_TWS",
         parameters = Map(
           "message" -> s"Implicit key not found in state store for 
stateName=$stateName"
@@ -92,7 +94,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState: ValueState[Long] = 
handle.getValueState[Long]("testState", Encoders.scalaLong)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
@@ -118,7 +121,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState1: ValueState[Long] = handle.getValueState[Long](
         "testState1", Encoders.scalaLong)
@@ -164,7 +168,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
         UUID.randomUUID(), 
Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
-        TimeoutMode.NoTimeouts())
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val cfName = "_testState"
       val ex = intercept[SparkUnsupportedOperationException] {
@@ -204,7 +208,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState: ValueState[Double] = 
handle.getValueState[Double]("testState",
         Encoders.scalaDouble)
@@ -230,7 +235,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState: ValueState[Long] = handle.getValueState[Long]("testState",
         Encoders.scalaLong)
@@ -256,7 +262,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState: ValueState[TestClass] = 
handle.getValueState[TestClass]("testState",
         Encoders.product[TestClass])
@@ -282,7 +289,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
-        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], 
TimeoutMode.NoTimeouts())
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.NoTTL(), TimeoutMode.NoTimeouts())
 
       val testState: ValueState[POJOTestClass] = 
handle.getValueState[POJOTestClass]("testState",
         Encoders.bean(classOf[POJOTestClass]))
@@ -303,6 +311,93 @@ class ValueStateSuite extends StateVariableSuiteBase {
       assert(testState.get() === null)
     }
   }
+
+  test("test Value state TTL") {
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
+      val store = provider.getStore(0)
+      val timestampMs = 10
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.ProcessingTimeTTL(), TimeoutMode.NoTimeouts(),
+        batchTimestampMs = Some(timestampMs))
+
+      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+      val testState: ValueStateImplWithTTL[String] = 
handle.getValueState[String]("testState",
+        Encoders.STRING, ttlConfig).asInstanceOf[ValueStateImplWithTTL[String]]
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+      testState.update("v1")
+      assert(testState.get() === "v1")
+      assert(testState.getWithoutEnforcingTTL().get === "v1")
+
+      val ttlExpirationMs = timestampMs + 60000
+      var ttlValue = testState.getTTLValue()
+      assert(ttlValue.isDefined)
+      assert(ttlValue.get === ttlExpirationMs)
+      var ttlStateValueIterator = testState.getValuesInTTLState()
+      assert(ttlStateValueIterator.hasNext)
+
+      // advance the batch processing time and ensure the expired value
+      // is not returned
+      val nextBatchHandle = new StatefulProcessorHandleImpl(store, 
UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.ProcessingTimeTTL(), TimeoutMode.NoTimeouts(),
+        batchTimestampMs = Some(ttlExpirationMs))
+
+      val nextBatchTestState: ValueStateImplWithTTL[String] =
+        nextBatchHandle.getValueState[String]("testState", Encoders.STRING, 
ttlConfig)
+          .asInstanceOf[ValueStateImplWithTTL[String]]
+
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+
+      // ensure get does not return the expired value
+      assert(!nextBatchTestState.exists())
+      assert(nextBatchTestState.get() === null)
+
+      // ttl value should still exist in state
+      ttlValue = nextBatchTestState.getTTLValue()
+      assert(ttlValue.isDefined)
+      assert(ttlValue.get === ttlExpirationMs)
+      ttlStateValueIterator = nextBatchTestState.getValuesInTTLState()
+      assert(ttlStateValueIterator.hasNext)
+      assert(ttlStateValueIterator.next() === ttlExpirationMs)
+      assert(ttlStateValueIterator.isEmpty)
+
+      // getWithoutEnforcingTTL should still return the expired value
+      assert(nextBatchTestState.getWithoutEnforcingTTL().get === "v1")
+
+      nextBatchTestState.clear()
+      assert(!nextBatchTestState.exists())
+      assert(nextBatchTestState.get() === null)
+    }
+  }
+
+  test("test negative or zero TTL duration throws error") {
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
+      val store = provider.getStore(0)
+      val batchTimestampMs = 10
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TTLMode.ProcessingTimeTTL(), TimeoutMode.NoTimeouts(),
+        batchTimestampMs = Some(batchTimestampMs))
+
+      Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
+        val ttlConfig = TTLConfig(ttlDuration)
+        val ex = intercept[SparkUnsupportedOperationException] {
+          handle.getValueState[String]("testState", Encoders.STRING, ttlConfig)
+        }
+
+        checkError(
+          ex,
+          errorClass = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE",
+          parameters = Map(
+            "operationType" -> "update",
+            "stateName" -> "testState"
+          ),
+          matchPVals = true
+        )
+      }
+    }
+  }
 }
 
 /**
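
The timing arithmetic the TTL test above relies on, spelled out with the
values pinned by the test:

```
// batchTimestampMs is 10 and the TTL is one minute, so the stored
// expiration is 10 + 60000 = 60010 ms. Re-opening the state with a batch
// timestamp of exactly 60010 makes the entry expired, since isExpired
// uses >=.
val timestampMs = 10L
val ttlExpirationMs = timestampMs + java.time.Duration.ofMinutes(1).toMillis
assert(ttlExpirationMs == 60010L)
assert(StateTTL.isExpired(ttlExpirationMs, batchTtlExpirationMs = ttlExpirationMs))
```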
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
index 95ab34d40131..5ccc14ab8a77 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -32,7 +32,8 @@ class TestListStateProcessor
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode): Unit = {
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
     _listState = getHandle.getListState("testListState", Encoders.STRING)
   }
 
@@ -89,7 +90,8 @@ class ToggleSaveAndEmitProcessor
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode): Unit = {
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
     _listState = getHandle.getListState("testListState", Encoders.STRING)
     _valueState = getHandle.getValueState("testValueState", 
Encoders.scalaBoolean)
   }
@@ -140,6 +142,7 @@ class TransformWithListStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestListStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update()) (
@@ -160,6 +163,7 @@ class TransformWithListStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestListStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -180,6 +184,7 @@ class TransformWithListStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestListStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -200,6 +205,7 @@ class TransformWithListStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestListStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -220,6 +226,7 @@ class TransformWithListStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestListStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -240,6 +247,7 @@ class TransformWithListStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestListStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -260,6 +268,7 @@ class TransformWithListStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestListStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update()) (
@@ -312,6 +321,7 @@ class TransformWithListStateSuite extends StreamTest
         .groupByKey(x => x)
         .transformWithState(new ToggleSaveAndEmitProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
index db8cb8b810af..d32b9687d95f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
@@ -32,7 +32,8 @@ class TestMapStateProcessor
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode): Unit = {
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
     _mapState = getHandle.getMapState("sessionState", Encoders.STRING, 
Encoders.STRING)
   }
 
@@ -95,6 +96,7 @@ class TransformWithMapStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestMapStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
 
@@ -121,6 +123,7 @@ class TransformWithMapStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestMapStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -145,6 +148,7 @@ class TransformWithMapStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestMapStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -168,6 +172,7 @@ class TransformWithMapStateSuite extends StreamTest
         .groupByKey(x => x.key)
         .transformWithState(new TestMapStateProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Append())
       testStream(result, OutputMode.Append())(
         // Test exists()
@@ -222,6 +227,7 @@ class TransformWithMapStateSuite extends StreamTest
       .groupByKey(x => x.key)
       .transformWithState(new TestMapStateProcessor(),
         TimeoutMode.NoTimeouts(),
+        TTLMode.NoTTL(),
         OutputMode.Append())
 
     val df = result.toDF()
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
index 147a13251044..106f228ba78b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
@@ -36,7 +36,10 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
   @transient var _listState: ListState[Double] = _
   @transient var _mapState: MapState[Double, Int] = _
 
-  override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
     _valState = getHandle.getValueState[Double]("testValueInit", 
Encoders.scalaDouble)
     _listState = getHandle.getListState[Double]("testListInit", 
Encoders.scalaDouble)
     _mapState = getHandle.getMapState[Double, Int](
@@ -168,7 +171,8 @@ class StatefulProcessorWithInitialStateProcTimerClass
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode) : Unit = {
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode) : Unit = {
     _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong)
     _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong)
   }
@@ -211,7 +215,8 @@ class StatefulProcessorWithInitialStateEventTimerClass
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode): Unit = {
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
     _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState",
       Encoders.scalaLong)
     _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong)
@@ -288,7 +293,7 @@ class TransformWithStateInitialStateSuite extends 
StateStoreMetricsTest
           InputRowForInitialState("init_2", 100.0, List(100.0), Map(100.0 -> 
1)))
           .toDS().groupByKey(x => x.key).mapValues(x => x)
       val query = kvDataSet.transformWithState(new 
InitialStateInMemoryTestClass(),
-            TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf)
+        TimeoutMode.NoTimeouts(), TTLMode.NoTTL(), OutputMode.Append(), 
initStateDf)
 
       testStream(query, OutputMode.Update())(
         // non-exist key test
@@ -366,7 +371,7 @@ class TransformWithStateInitialStateSuite extends 
StateStoreMetricsTest
       val query = inputData.toDS()
         .groupByKey(x => x.key)
         .transformWithState(new AccumulateStatefulProcessorWithInitState(),
-          TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf
+          TimeoutMode.NoTimeouts(), TTLMode.NoTTL(), OutputMode.Append(), 
initStateDf
         )
       testStream(query, OutputMode.Update())(
         AddData(inputData, InitInputRow("init_1", "add", 50.0)),
@@ -387,6 +392,7 @@ class TransformWithStateInitialStateSuite extends 
StateStoreMetricsTest
       .groupByKey(x => x.key)
       .transformWithState(new AccumulateStatefulProcessorWithInitState(),
         TimeoutMode.NoTimeouts(),
+        TTLMode.NoTTL(),
         OutputMode.Append(),
         createInitialDfForTest)
 
@@ -405,6 +411,7 @@ class TransformWithStateInitialStateSuite extends 
StateStoreMetricsTest
         .groupByKey(x => x.key)
         .transformWithState(new AccumulateStatefulProcessorWithInitState(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Append(),
           initDf)
 
@@ -437,6 +444,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
         .transformWithState(
           new StatefulProcessorWithInitialStateProcTimerClass(),
           TimeoutMode.ProcessingTime(),
+          TTLMode.NoTTL(),
           OutputMode.Update(),
           initDf)
 
@@ -481,6 +489,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
         .transformWithState(
           new StatefulProcessorWithInitialStateEventTimerClass(),
           TimeoutMode.EventTime(),
+          TTLMode.NoTTL(),
           OutputMode.Update(),
           initDf)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 2fd1eac179da..735c53bf3c91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -40,7 +40,8 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode): Unit = {
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
     _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
   }
 
@@ -103,8 +104,9 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode) : Unit = {
-    super.init(outputMode, timeoutMode)
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode) : Unit = {
+    super.init(outputMode, timeoutMode, ttlMode)
     _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong)
   }
 
@@ -194,7 +196,8 @@ class MaxEventTimeStatefulProcessor
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode): Unit = {
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
     _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState",
       Encoders.scalaLong)
     _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong)
@@ -239,10 +242,12 @@ class RunningCountMostRecentStatefulProcessor
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode): Unit = {
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
     _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
     _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING)
   }
+
   override def handleInputRows(
       key: String,
       inputRows: Iterator[(String, String)],
@@ -268,7 +273,8 @@ class MostRecentStatefulProcessorWithDeletion
 
   override def init(
       outputMode: OutputMode,
-      timeoutMode: TimeoutMode): Unit = {
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
     getHandle.deleteIfExists("countState")
     _mostRecent = getHandle.getValueState[String]("mostRecent", 
Encoders.STRING)
   }
@@ -322,6 +328,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .groupByKey(x => x)
         .transformWithState(new RunningCountStatefulProcessorWithError(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -343,6 +350,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .groupByKey(x => x)
         .transformWithState(new RunningCountStatefulProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -373,6 +381,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .groupByKey(x => x)
        .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(),
           TimeoutMode.ProcessingTime(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -416,6 +425,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .transformWithState(
           new RunningCountStatefulProcessorWithProcTimeTimerUpdates(),
           TimeoutMode.ProcessingTime(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -452,6 +462,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .transformWithState(
           new RunningCountStatefulProcessorWithMultipleTimers(),
           TimeoutMode.ProcessingTime(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -487,6 +498,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .transformWithState(
           new MaxEventTimeStatefulProcessor(),
           TimeoutMode.EventTime(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
     testStream(result, OutputMode.Update())(
@@ -528,6 +540,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       .groupByKey(x => x)
       .transformWithState(new RunningCountStatefulProcessor(),
         TimeoutMode.NoTimeouts(),
+        TTLMode.NoTTL(),
         OutputMode.Append())
 
     val df = result.toDF()
@@ -546,12 +559,14 @@ class TransformWithStateSuite extends StateStoreMetricsTest
           .groupByKey(x => x._1)
           .transformWithState(new RunningCountMostRecentStatefulProcessor(),
             TimeoutMode.NoTimeouts(),
+            TTLMode.NoTTL(),
             OutputMode.Update())
 
         val stream2 = inputData.toDS()
           .groupByKey(x => x._1)
           .transformWithState(new MostRecentStatefulProcessorWithDeletion(),
             TimeoutMode.NoTimeouts(),
+            TTLMode.NoTTL(),
             OutputMode.Update())
 
         testStream(stream1, OutputMode.Update())(
@@ -584,6 +599,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .groupByKey(x => x)
         .transformWithState(new RunningCountStatefulProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -617,6 +633,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .groupByKey(x => x)
         .transformWithState(new RunningCountStatefulProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -650,6 +667,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .groupByKey(x => x)
         .transformWithState(new RunningCountStatefulProcessor(),
           TimeoutMode.NoTimeouts(),
+          TTLMode.NoTTL(),
           OutputMode.Update())
 
       testStream(result, OutputMode.Update())(
@@ -680,6 +698,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       .groupByKey(x => x)
       .transformWithState(new RunningCountStatefulProcessor(),
         TimeoutMode.NoTimeouts(),
+        TTLMode.NoTTL(),
         OutputMode.Update())
   }
 
@@ -772,6 +791,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest {
       .groupByKey(x => x)
       .transformWithState(new RunningCountStatefulProcessor(),
         TimeoutMode.NoTimeouts(),
+        TTLMode.NoTTL(),
         OutputMode.Update())
 
     testStream(result, OutputMode.Update())(
@@ -790,7 +810,7 @@ class TransformWithStateValidationSuite extends StateStoreMetricsTest {
     val result = inputData.toDS()
       .groupByKey(x => x.key)
       .transformWithState(new AccumulateStatefulProcessorWithInitState(),
-        TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf
+        TimeoutMode.NoTimeouts(), TTLMode.NoTTL(), OutputMode.Append(), initDf
       )
     testStream(result, OutputMode.Update())(
       AddData(inputData, InitInputRow("a", "add", -1.0)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
new file mode 100644
index 000000000000..759d535c18a3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -0,0 +1,471 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.sql.Timestamp
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl, ValueStateImplWithTTL}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+case class InputEvent(
+    key: String,
+    action: String,
+    value: Int,
+    eventTime: Timestamp = null)
+
+case class OutputEvent(
+    key: String,
+    value: Int,
+    isTTLValue: Boolean,
+    ttlValue: Long)
+
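+// Interprets the "action" field of each InputEvent against the supplied value
+// state: "put" writes the value, "get" reads it with TTL enforced, and the
+// remaining actions read the TTL bookkeeping directly so tests can assert on it.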
+object TTLInputProcessFunction {
+  def processRow(
+      row: InputEvent,
+      valueState: ValueStateImplWithTTL[Int]): Iterator[OutputEvent] = {
+    var results = List[OutputEvent]()
+    val key = row.key
+    if (row.action == "get") {
+      val currState = valueState.getOption()
+      if (currState.isDefined) {
+        results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results
+      }
+    } else if (row.action == "get_without_enforcing_ttl") {
+      val currState = valueState.getWithoutEnforcingTTL()
+      if (currState.isDefined) {
+        results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results
+      }
+    } else if (row.action == "get_ttl_value_from_state") {
+      val ttlExpiration = valueState.getTTLValue()
+      if (ttlExpiration.isDefined) {
+        results = OutputEvent(key, -1, isTTLValue = true, ttlExpiration.get) :: results
+      }
+    } else if (row.action == "put") {
+      valueState.update(row.value)
+    } else if (row.action == "get_values_in_ttl_state") {
+      val ttlValues = valueState.getValuesInTTLState()
+      ttlValues.foreach { v =>
+        results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results
+      }
+    }
+
+    results.iterator
+  }
+
+  def processNonTTLStateRow(
+      row: InputEvent,
+      valueState: ValueStateImpl[Int]): Iterator[OutputEvent] = {
+    var results = List[OutputEvent]()
+    val key = row.key
+    if (row.action == "get") {
+      val currState = valueState.getOption()
+      if (currState.isDefined) {
+        results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results
+      }
+    } else if (row.action == "put") {
+      valueState.update(row.value)
+    }
+
+    results.iterator
+  }
+}
+
+class ValueStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+  with Logging {
+
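+  // The cast to ValueStateImplWithTTL exposes the TTL inspection helpers
+  // (getWithoutEnforcingTTL, getTTLValue, getValuesInTTLState) that
+  // TTLInputProcessFunction relies on.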
+  @transient private var _valueState: ValueStateImplWithTTL[Int] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
+    _valueState = getHandle
+      .getValueState("valueState", Encoders.scalaInt, ttlConfig)
+      .asInstanceOf[ValueStateImplWithTTL[Int]]
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[InputEvent],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+    var results = List[OutputEvent]()
+
+    inputRows.foreach { row =>
+      val resultIter = TTLInputProcessFunction.processRow(row, _valueState)
+      resultIter.foreach { r =>
+        results = r :: results
+      }
+    }
+
+    results.iterator
+  }
+}
+
+case class MultipleValueStatesTTLProcessor(
+    ttlKey: String,
+    noTtlKey: String,
+    ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+    with Logging {
+
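+  // Rows for ttlKey go through a TTL-enabled value state, while rows for any
+  // other key go through a plain value state, so both behaviors can be
+  // compared within a single query.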
+  @transient private var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _
+  @transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeoutMode: TimeoutMode,
+      ttlMode: TTLMode): Unit = {
+    _valueStateWithTTL = getHandle
+      .getValueState("valueState", Encoders.scalaInt, ttlConfig)
+      .asInstanceOf[ValueStateImplWithTTL[Int]]
+    _valueStateWithoutTTL = getHandle
+      .getValueState("valueState", Encoders.scalaInt)
+      .asInstanceOf[ValueStateImpl[Int]]
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[InputEvent],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+    var results = List[OutputEvent]()
+
+    if (key == ttlKey) {
+      inputRows.foreach { row =>
+        val resultIterator = TTLInputProcessFunction.processRow(row, _valueStateWithTTL)
+        resultIterator.foreach { r =>
+          results = r :: results
+        }
+      }
+    } else {
+      inputRows.foreach { row =>
+        val resultIterator = TTLInputProcessFunction.processNonTTLStateRow(row,
+          _valueStateWithoutTTL)
+        resultIterator.foreach { r =>
+          results = r :: results
+        }
+      }
+    }
+
+    results.iterator
+  }
+}
+
+/**
+ * Tests that TTL works as expected for ValueState, for both
+ * processing-time and event-time based TTL.
+ */
+class TransformWithValueStateTTLSuite
+  extends StreamTest {
+  import testImplicits._
+
+  test("validate state is evicted at ttl expiry") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      withTempDir { dir =>
+        val inputStream = MemoryStream[InputEvent]
+        val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+        val result = inputStream.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(
+            new ValueStateTTLProcessor(ttlConfig),
+            TimeoutMode.NoTimeouts(),
+            TTLMode.ProcessingTimeTTL(),
+            OutputMode.Append())
+
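+        // a manual clock makes processing-time advancement (and thus TTL expiry) deterministic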
+        val clock = new StreamManualClock
+        testStream(result)(
+          StartStream(
+            Trigger.ProcessingTime("1 second"),
+            triggerClock = clock,
+            checkpointLocation = dir.getAbsolutePath),
+          AddData(inputStream, InputEvent("k1", "put", 1)),
+          // advance clock to trigger processing
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(),
+          StopStream,
+          StartStream(
+            Trigger.ProcessingTime("1 second"),
+            triggerClock = clock,
+            checkpointLocation = dir.getAbsolutePath),
+          // get this state, and make sure we get unexpired value
+          AddData(inputStream, InputEvent("k1", "get", -1)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+          StopStream,
+          StartStream(
+            Trigger.ProcessingTime("1 second"),
+            triggerClock = clock,
+            checkpointLocation = dir.getAbsolutePath),
+          // ensure ttl values were added correctly
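+          // (the put was processed at t = 1s, so expiration = 1000 + 60000 = 61000 ms)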
+          AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", 
-1)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", 
-1)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+          StopStream,
+          StartStream(
+            Trigger.ProcessingTime("1 second"),
+            triggerClock = clock,
+            checkpointLocation = dir.getAbsolutePath),
+          // advance clock so that state expires
+          AdvanceManualClock(60 * 1000),
+          AddData(inputStream, InputEvent("k1", "get", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          // validate expired value is not returned
+          CheckNewAnswer(),
+          // ensure this state does not exist any longer in State
+          AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", 
-1)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(),
+          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", 
-1)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer()
+        )
+      }
+    }
+  }
+
+  test("validate state update updates the expiration timestamp") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val inputStream = MemoryStream[InputEvent]
+      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+      val result = inputStream.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(
+          new ValueStateTTLProcessor(ttlConfig),
+          TimeoutMode.NoTimeouts(),
+          TTLMode.ProcessingTimeTTL(),
+          OutputMode.Append())
+
+      val clock = new StreamManualClock
+      testStream(result)(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputStream, InputEvent("k1", "put", 1)),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // get this state, and make sure we get unexpired value
+        AddData(inputStream, InputEvent("k1", "get", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // ensure ttl values were added correctly
+        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+        // advance clock and update expiration time
+        AdvanceManualClock(30 * 1000),
+        AddData(inputStream, InputEvent("k1", "put", 1)),
+        AddData(inputStream, InputEvent("k1", "get", -1)),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        // validate value is not expired
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // validate ttl value is updated in the state
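+        // (the second put was processed at t = 35s, so the new expiration is 35000 + 60000 = 95000 ms)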
+        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)),
+        // validate ttl state has both ttl values present
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000),
+          OutputEvent("k1", -1, isTTLValue = true, 95000)
+        ),
+        // advance clock after older expiration value
+        AdvanceManualClock(30 * 1000),
+        // ensure unexpired value is still present in the state
+        AddData(inputStream, InputEvent("k1", "get", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // validate that the older expiration value is removed from ttl state
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000))
+      )
+    }
+  }
+
+  test("validate state is evicted at ttl expiry for no data batch") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+    classOf[RocksDBStateStoreProvider].getName) {
+      val inputStream = MemoryStream[InputEvent]
+      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+      val result = inputStream.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(
+          new ValueStateTTLProcessor(ttlConfig),
+          TimeoutMode.NoTimeouts(),
+          TTLMode.ProcessingTimeTTL(),
+          OutputMode.Append())
+
+      val clock = new StreamManualClock
+      testStream(result)(
+        StartStream(
+          Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputStream, InputEvent("k1", "put", 1)),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // get this state, and make sure we get unexpired value
+        AddData(inputStream, InputEvent("k1", "get", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
+        // ensure ttl values were added correctly
+        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+        // advance clock so that state expires
+        AdvanceManualClock(60 * 1000),
+        // run a no data batch
+        CheckNewAnswer(),
+        AddData(inputStream, InputEvent("k1", "get", -1)),
+        AdvanceManualClock(1 * 1000),
+        // validate expired value is not returned
+        CheckNewAnswer(),
+        // ensure this state does not exist any longer in State
+        AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", 
-1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer()
+      )
+    }
+  }
+
+  test("validate multiple value states") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val ttlKey = "k1"
+      val noTtlKey = "k2"
+      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+      val inputStream = MemoryStream[InputEvent]
+      val result = inputStream.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(
+          MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig),
+          TimeoutMode.NoTimeouts(),
+          TTLMode.ProcessingTimeTTL(),
+          OutputMode.Append())
+
+      val clock = new StreamManualClock
+      testStream(result)(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputStream, InputEvent(ttlKey, "put", 1)),
+        AddData(inputStream, InputEvent(noTtlKey, "put", 2)),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // get both state values, and make sure we get unexpired value
+        AddData(inputStream, InputEvent(ttlKey, "get", -1)),
+        AddData(inputStream, InputEvent(noTtlKey, "get", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(
+          OutputEvent(ttlKey, 1, isTTLValue = false, -1),
+          OutputEvent(noTtlKey, 2, isTTLValue = false, -1)
+        ),
+        // ensure ttl values were added correctly, and noTtlKey has no ttl values
+        AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", 
-1)),
+        AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", 
-1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)),
+        AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", 
-1)),
+        AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", 
-1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)),
+        // advance clock after expiry
+        AdvanceManualClock(60 * 1000),
+        AddData(inputStream, InputEvent(ttlKey, "get", -1)),
+        AddData(inputStream, InputEvent(noTtlKey, "get", -1)),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        // validate ttlKey is expired, but noTtlKey is still present
+        CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)),
+        // validate ttl value is removed from the value state column family
+        AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer()
+      )
+    }
+  }
+
+  test("validate only expired keys are removed from the state") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      val inputStream = MemoryStream[InputEvent]
+      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+      val result = inputStream.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(
+          new ValueStateTTLProcessor(ttlConfig),
+          TimeoutMode.NoTimeouts(),
+          TTLMode.ProcessingTimeTTL(),
+          OutputMode.Append())
+
+      val clock = new StreamManualClock
+      testStream(result)(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputStream, InputEvent("k1", "put", 1)),
+        // advance clock to trigger processing
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // advance clock halfway to expiration ttl, and add another key
+        AdvanceManualClock(30 * 1000),
+        AddData(inputStream, InputEvent("k2", "put", 2)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // advance clock so that key k1 is expired
+        AdvanceManualClock(30 * 1000),
+        AddData(inputStream, InputEvent("k1", "get", 1)),
+        AddData(inputStream, InputEvent("k2", "get", -1)),
+        AdvanceManualClock(1 * 1000),
+        // validate k1 is expired and k2 is not
+        CheckNewAnswer(OutputEvent("k2", 2, isTTLValue = false, -1)),
+        // validate k1 is deleted from state
+        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // validate k2 exists in state
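+        // (k2 was written at t = 32s, so its expiration is 32000 + 60000 = 92000 ms)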
+        AddData(inputStream, InputEvent("k2", "get_ttl_value_from_state", -1)),
+        AddData(inputStream, InputEvent("k2", "get_values_in_ttl_state", -1)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(
+          OutputEvent("k2", -1, isTTLValue = true, 92000),
+          OutputEvent("k2", -1, isTTLValue = true, 92000))
+      )
+    }
+  }
+}
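
For reference, here is a minimal sketch of how a user-defined processor opts
into processing-time TTL with the API added in this change. `CountProcessor`
and its state-variable name are illustrative only, not part of the change; the
calls mirror the ones exercised by the test suites above.

```
import java.time.Duration

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// Illustrative processor: keeps a running count per key, where each stored
// value expires one minute after it was last written.
class CountProcessor(ttlConfig: TTLConfig)
  extends StatefulProcessor[String, String, (String, Long)] {

  @transient private var countState: ValueState[Long] = _

  override def init(
      outputMode: OutputMode,
      timeoutMode: TimeoutMode,
      ttlMode: TTLMode): Unit = {
    // Passing a TTLConfig enables TTL for this state variable.
    countState = getHandle.getValueState[Long](
      "countState", Encoders.scalaLong, ttlConfig)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = {
    val count = countState.getOption().getOrElse(0L) + inputRows.size
    // Each update refreshes the expiration timestamp for this key.
    countState.update(count)
    Iterator.single((key, count))
  }
}

// Wiring the processor into a query, mirroring the test suites:
// ds.groupByKey(identity)
//   .transformWithState(
//     new CountProcessor(TTLConfig(ttlDuration = Duration.ofMinutes(1))),
//     TimeoutMode.NoTimeouts(),
//     TTLMode.ProcessingTimeTTL(),
//     OutputMode.Append())
```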

