This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new be080703688f [SPARK-47673][SS] Implementing TTL for ListState
be080703688f is described below

commit be080703688f8c59f8e7a0b24ce747d9ba14264e
Author: Eric Marnadi <eric.marn...@databricks.com>
AuthorDate: Tue Apr 16 11:04:15 2024 +0900

    [SPARK-47673][SS] Implementing TTL for ListState
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for expiring state based on TTL for ListState. Using
    this functionality, Spark users can specify a TimeMode for the
    transformWithState operator and provide a TTLConfig (ttlDuration) when
    creating a ListState; the duration applies to every value written to the
    list. TTL support for MapState will be added in future PRs. Once the
    ttlDuration has expired, the value will not be returned as part of get()
    and will be cleaned up at the end of the micro-batch.
    
    ### Why are the changes needed?
    
    These changes are needed to support TTL for ListState. The PR supports
    specifying TTL in processing-time mode.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. It adds a `getListState` overload to `StatefulProcessorHandle` that
    accepts a `TTLConfig` specifying the ttlDuration.
    
    ### How was this patch tested?
    
    Added the new TransformWithListStateTTLSuite, and new tests in
    ListStateSuite and StatefulProcessorHandleSuite.

    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45932 from ericm-db/ls-ttl.
    
    Authored-by: Eric Marnadi <eric.marn...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/streaming/StatefulProcessorHandle.scala    |  20 +
 .../execution/streaming/ListStateImplWithTTL.scala | 220 ++++++++++
 .../streaming/StateTypesEncoderUtils.scala         |   5 +
 .../streaming/StatefulProcessorHandleImpl.scala    |  32 ++
 .../spark/sql/execution/streaming/TTLState.scala   |  32 +-
 .../streaming/TransformWithStateExec.scala         |   2 +
 .../streaming/ValueStateImplWithTTL.scala          |  46 +--
 .../execution/streaming/state/ListStateSuite.scala |  90 +++-
 .../state/StatefulProcessorHandleSuite.scala       |  25 +-
 .../streaming/state/ValueStateSuite.scala          |   4 +-
 .../streaming/TransformWithListStateTTLSuite.scala | 454 +++++++++++++++++++++
 ...Suite.scala => TransformWithStateTTLTest.scala} | 220 +---------
 .../TransformWithValueStateTTLSuite.scala          | 270 +-----------
 13 files changed, 911 insertions(+), 509 deletions(-)

diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
index e65667206ded..f662b685c4e4 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
@@ -72,6 +72,26 @@ private[sql] trait StatefulProcessorHandle extends 
Serializable {
    */
   def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T]
 
+  /**
+   * Function to create new or return existing list state variable of given type
+   * with ttl. Each value written to the list is stamped with an expiration time of
+   * ttlDuration after the timestamp of the micro-batch in which it was written.
+   * Expired values are not returned by get() and are eventually removed from the
+   * state store.
+   *
+   * The user must ensure to call this function only within the `init()` method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+   * @tparam T - type of state variable
+   * @return - instance of ListState of type T that can be used to store state persistently
+   */
+  def getListState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ListState[T]
+
   /**
    * Creates new or returns existing map state associated with stateName.
    * The MapState persists Key-Value pairs of type [K, V].
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
new file mode 100644
index 000000000000..32bc21cea6ed
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
+import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{ListState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for a list state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param valEncoder - Spark SQL encoder for value
+ * @param ttlConfig  - TTL configuration for values stored in this state
+ * @param batchTimestampMs - current batch processing timestamp.
+ * @tparam S - data type of object that will be stored
+ */
+class ListStateImplWithTTL[S](
+    store: StateStore,
+    stateName: String,
+    keyExprEnc: ExpressionEncoder[Any],
+    valEncoder: Encoder[S],
+    ttlConfig: TTLConfig,
+    batchTimestampMs: Long)
+  extends SingleKeyTTLStateImpl(stateName, store, batchTimestampMs) with ListState[S] {
+
+  private lazy val keySerializer = keyExprEnc.createSerializer()
+
+  private lazy val stateTypesEncoder = StateTypesEncoder(
+    keySerializer, valEncoder, stateName, hasTtl = true)
+
+  // Computed once per instance: every value written through this instance shares the
+  // expiration timestamp derived from the configured duration and the batch timestamp.
+  private lazy val ttlExpirationMs =
+    StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+    store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+      NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true)
+  }
+
+  /** Whether at least one unexpired value exists for the current grouping key. */
+  override def exists(): Boolean = {
+    get().nonEmpty
+  }
+
+  /**
+   * Get the state value if it exists. If the state does not exist in state store, an
+   * empty iterator is returned. Values whose TTL has expired as of `batchTimestampMs`
+   * are skipped even if they have not been cleaned up from the store yet.
+   */
+  override def get(): Iterator[S] = {
+    val encodedKey = stateTypesEncoder.encodeGroupingKey()
+    // A single lazy filtering pass over the store iterator; the underlying iterator is
+    // never touched again after being wrapped, as the Iterator contract requires.
+    val unexpiredRows = store.valuesIterator(encodedKey, stateName).filterNot { row =>
+      stateTypesEncoder.isExpired(row, batchTimestampMs)
+    }
+
+    new NextIterator[S] {
+
+      override protected def getNext(): S = {
+        if (unexpiredRows.hasNext) {
+          stateTypesEncoder.decodeValue(unexpiredRows.next())
+        } else {
+          finished = true
+          null.asInstanceOf[S]
+        }
+      }
+
+      override protected def close(): Unit = {}
+    }
+  }
+
+  /**
+   * Replace the whole list with `newState`. Every element is stored with this batch's
+   * expiration timestamp, and the TTL index entry for the grouping key is upserted.
+   */
+  override def put(newState: Array[S]): Unit = {
+    validateNewState(newState)
+
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    val encodedKey = stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey)
+    var isFirst = true
+
+    newState.foreach { v =>
+      val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
+      // The first `put` overwrites any existing list; subsequent values are merged onto it.
+      if (isFirst) {
+        store.put(encodedKey, encodedValue, stateName)
+        isFirst = false
+      } else {
+        store.merge(encodedKey, encodedValue, stateName)
+      }
+    }
+    upsertTTLForStateKey(serializedGroupingKey)
+  }
+
+  /** Append an entry to the list, stamped with this batch's expiration timestamp. */
+  override def appendValue(newState: S): Unit = {
+    StateStoreErrors.requireNonNullStateValue(newState, stateName)
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    store.merge(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey),
+      stateTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName)
+    upsertTTLForStateKey(serializedGroupingKey)
+  }
+
+  /** Append an entire list to the existing value; each element gets this batch's TTL. */
+  override def appendList(newState: Array[S]): Unit = {
+    validateNewState(newState)
+
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    val encodedKey = stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey)
+    newState.foreach { v =>
+      val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
+      store.merge(encodedKey, encodedValue, stateName)
+    }
+    upsertTTLForStateKey(serializedGroupingKey)
+  }
+
+  /**
+   * Remove this state. TTL index entries for the grouping key are not removed here.
+   * NOTE(review): presumably they are purged by the TTL cleanup pass once they expire.
+   */
+  override def clear(): Unit = {
+    store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
+  }
+
+  /** Rejects a null or empty array, and any null element inside it. */
+  private def validateNewState(newState: Array[S]): Unit = {
+    StateStoreErrors.requireNonNullStateValue(newState, stateName)
+    StateStoreErrors.requireNonEmptyListStateValue(newState, stateName)
+
+    newState.foreach { v =>
+      StateStoreErrors.requireNonNullStateValue(v, stateName)
+    }
+  }
+
+  /**
+   * Loops through all the values associated with the grouping key, and removes
+   * the expired elements from the list, preserving the order of the remaining ones.
+   *
+   * @param groupingKey serialized grouping key for which cleanup should be performed.
+   * @return the number of list elements removed because their TTL had expired.
+   */
+  override def clearIfExpired(groupingKey: Array[Byte]): Long = {
+    val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey)
+    // Buffer the current values before modifying the key, so the rewrite does not depend on
+    // the store's iterator being detached from later writes. Rows are copied because a
+    // store may hand out the same UnsafeRow instance for every element it returns.
+    val currentRows =
+      store.valuesIterator(encodedGroupingKey, stateName).map(_.copy()).toList
+    val (expiredRows, unexpiredRows) = currentRows.partition { row =>
+      stateTypesEncoder.isExpired(row, batchTimestampMs)
+    }
+
+    // Only rewrite the list when something actually expired; otherwise it is unchanged.
+    if (expiredRows.nonEmpty) {
+      store.remove(encodedGroupingKey, stateName)
+      unexpiredRows match {
+        case head :: tail =>
+          store.put(encodedGroupingKey, head, stateName)
+          tail.foreach(row => store.merge(encodedGroupingKey, row, stateName))
+        case Nil =>
+      }
+    }
+    expiredRows.size.toLong
+  }
+
+  private def upsertTTLForStateKey(serializedGroupingKey: Array[Byte]): Unit = {
+    upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey)
+  }
+
+  /*
+   * Internal methods to probe state for testing. The below methods exist for unit tests
+   * to read the state ttl values, and ensure that values are persisted correctly in
+   * the underlying state store.
+   */
+
+  /**
+   * Retrieves the values from state even if they are expired. This method is used
+   * in tests to read the state store value, and ensure it is cleaned up at the
+   * end of the micro-batch.
+   */
+  private[sql] def getWithoutEnforcingTTL(): Iterator[S] = {
+    val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+    val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
+    unsafeRowValuesIterator.map { valueUnsafeRow =>
+      stateTypesEncoder.decodeValue(valueUnsafeRow)
+    }
+  }
+
+  /**
+   * Read every stored value for the current grouping key together with its ttl
+   * expiration timestamp, including values that have already expired.
+   */
+  private[sql] def getTTLValues(): Iterator[(S, Long)] = {
+    val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+    val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
+    unsafeRowValuesIterator.map { valueUnsafeRow =>
+      (stateTypesEncoder.decodeValue(valueUnsafeRow),
+        stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get)
+    }
+  }
+
+  /**
+   * Get all ttl values stored in ttl state for current implicit
+   * grouping key.
+   */
+  private[sql] def getValuesInTTLState(): Iterator[Long] = {
+    getValuesInTTLState(stateTypesEncoder.serializeGroupingKey())
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
index b2dba7668d62..56b0731e0db4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
@@ -139,6 +139,11 @@ class StateTypesEncoder[GK, V](
       Some(expirationMs)
     }
   }
+
+  /**
+   * Returns true when `row` carries a TTL expiration timestamp that `StateTTL.isExpired`
+   * reports as expired relative to `batchTimestampMs`. A row without an expiration
+   * timestamp is never considered expired.
+   */
+  def isExpired(row: UnsafeRow, batchTimestampMs: Long): Boolean =
+    decodeTtlExpirationMs(row).exists { ttlMs =>
+      StateTTL.isExpired(ttlMs, batchTimestampMs)
+    }
 }
 
 object StateTypesEncoder {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index b7fad69a9e0c..885df96a206a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -245,6 +245,38 @@ class StatefulProcessorHandleImpl(
     resultState
   }
 
+  /**
+   * Creates a TTL-enabled list state variable backed by the column family `stateName`
+   * (the column family is created if absent). Each value written through it is stamped
+   * with an expiration timestamp derived from `ttlConfig.ttlDuration` and the batch
+   * timestamp; expired values are not returned on get() and are eventually removed from
+   * the state store.
+   *
+   * The user must ensure to call this function only within the `init()` method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+   * @tparam T - type of state variable
+   * @return - instance of ListState of type T that can be used to store state persistently
+   */
+  override def getListState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ListState[T] = {
+    verifyStateVarOperations("get_list_state")
+    validateTTLConfig(ttlConfig, stateName)
+
+    // Expiration times are computed from the batch timestamp, so one must be present.
+    assert(batchTimestampMs.isDefined)
+    val ttlListState = new ListStateImplWithTTL[T](
+      store, stateName, keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
+    incrementMetric("numListStateWithTTLVars")
+    // Track the variable in ttlStates (presumably iterated for TTL cleanup — confirm).
+    ttlStates.add(ttlListState)
+    ttlListState
+  }
+
   override def getMapState[K, V](
       stateName: String,
       userKeyEnc: Encoder[K],
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
index 12e31791839a..b245f8fc14d4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
@@ -73,9 +73,9 @@ trait TTLState {
    *
    * @param groupingKey grouping key for which cleanup should be performed.
    *
-   * @return true if the state was cleared, false otherwise.
+   * @return how many state objects were cleaned up.
    */
-  def clearIfExpired(groupingKey: Array[Byte]): Boolean
+  def clearIfExpired(groupingKey: Array[Byte]): Long
 }
 
 /**
@@ -118,9 +118,7 @@ abstract class SingleKeyTTLStateImpl(
       StateTTL.isExpired(expirationMs, ttlExpirationMs)
     }.foreach { kv =>
       val groupingKey = kv.key.getBinary(1)
-      if (clearIfExpired(groupingKey)) {
-        numValuesExpired += 1
-      }
+      numValuesExpired += clearIfExpired(groupingKey)
       store.remove(kv.key, ttlColumnFamilyName)
     }
     numValuesExpired
@@ -141,6 +139,30 @@ abstract class SingleKeyTTLStateImpl(
       }
     }
   }
+
+  /**
+   * Returns an iterator over the expiration timestamps recorded in the TTL index for the
+   * given serialized grouping key. Used by the test-probe helpers of TTL state variables.
+   *
+   * @param groupingKey serialized grouping key whose TTL index entries should be returned.
+   * @return expiration timestamps (ms), in the order produced by `ttlIndexIterator()`.
+   */
+  private[sql] def getValuesInTTLState(groupingKey: Array[Byte]): Iterator[Long] = {
+    val ttlIterator = ttlIndexIterator()
+    var nextValue: Option[Long] = None
+
+    new Iterator[Long] {
+      override def hasNext: Boolean = {
+        // Advance through the TTL index until an entry for `groupingKey` is found.
+        while (nextValue.isEmpty && ttlIterator.hasNext) {
+          val nextTtlValue = ttlIterator.next()
+          if (nextTtlValue.groupingKey sameElements groupingKey) {
+            nextValue = Some(nextTtlValue.expirationMs)
+          }
+        }
+        nextValue.isDefined
+      }
+
+      override def next(): Long = {
+        // The Iterator contract allows next() without a preceding hasNext, so advance here.
+        if (!hasNext) {
+          throw new NoSuchElementException("No more TTL values for the grouping key")
+        }
+        val result = nextValue.get
+        nextValue = None
+        result
+      }
+    }
+  }
 }
 
 /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index be01d55b3f4d..f5d2610d78d9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -311,6 +311,8 @@ case class TransformWithStateExec(
       // metrics around TTL
       StatefulOperatorCustomSumMetric("numValueStateWithTTLVars",
         "Number of value state variables with TTL"),
+      StatefulOperatorCustomSumMetric("numListStateWithTTLVars",
+        "Number of list state variables with TTL"),
       StatefulOperatorCustomSumMetric("numValuesRemovedDueToTTLExpiry",
         "Number of values removed due to TTL expiry")
     )
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
index 89fe3c2b94ef..dbfa4586dc0a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
 import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
StateStore}
 import org.apache.spark.sql.streaming.{TTLConfig, ValueState}
@@ -75,7 +74,7 @@ class ValueStateImplWithTTL[S](
     if (retRow != null) {
       val resState = stateTypesEncoder.decodeValue(retRow)
 
-      if (!isExpired(retRow)) {
+      if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
         resState
       } else {
         null.asInstanceOf[S]
@@ -99,25 +98,20 @@ class ValueStateImplWithTTL[S](
     store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
   }
 
+  /**
+   * Removes the value stored for the serialized grouping key if its TTL has expired as of
+   * `batchTimestampMs`.
+   *
+   * @param groupingKey serialized grouping key read from the TTL index.
+   * @return 1 if an expired value was removed, 0 otherwise.
+   */
-  def clearIfExpired(groupingKey: Array[Byte]): Boolean = {
+  def clearIfExpired(groupingKey: Array[Byte]): Long = {
     val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey)
     val retRow = store.get(encodedGroupingKey, stateName)
 
-    var result = false
+    var result = 0L
     if (retRow != null) {
-      if (isExpired(retRow)) {
+      if (stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
         store.remove(encodedGroupingKey, stateName)
-        result = true
+        result = 1L
       }
     }
     result
   }
 
-  private def isExpired(valueRow: UnsafeRow): Boolean = {
-    val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow)
-    expirationMs.exists(StateTTL.isExpired(_, batchTimestampMs))
-  }
-
   /*
    * Internal methods to probe state for testing. The below methods exist for 
unit tests
    * to read the state ttl values, and ensure that values are persisted 
correctly in
@@ -144,12 +138,15 @@ class ValueStateImplWithTTL[S](
   /**
    * Read the ttl value associated with the grouping key.
    */
-  private[sql] def getTTLValue(): Option[Long] = {
+  private[sql] def getTTLValue(): Option[(S, Long)] = {
     val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
     val retRow = store.get(encodedGroupingKey, stateName)
 
+    // if the returned row is not null, we want to return the value associated 
with the
+    // ttlExpiration
     if (retRow != null) {
-      stateTypesEncoder.decodeTtlExpirationMs(retRow)
+      val ttlExpiration = stateTypesEncoder.decodeTtlExpirationMs(retRow)
+      ttlExpiration.map(expiration => (stateTypesEncoder.decodeValue(retRow), 
expiration))
     } else {
       None
     }
@@ -160,28 +157,7 @@ class ValueStateImplWithTTL[S](
    * grouping key.
    */
   private[sql] def getValuesInTTLState(): Iterator[Long] = {
-    val ttlIterator = ttlIndexIterator()
-    val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey()
-    var nextValue: Option[Long] = None
-
-    new Iterator[Long] {
-      override def hasNext: Boolean = {
-        while (nextValue.isEmpty && ttlIterator.hasNext) {
-          val nextTtlValue = ttlIterator.next()
-          val groupingKey = nextTtlValue.groupingKey
-          if (groupingKey sameElements implicitGroupingKey) {
-            nextValue = Some(nextTtlValue.expirationMs)
-          }
-        }
-        nextValue.isDefined
-      }
-
-      override def next(): Long = {
-        val result = nextValue.get
-        nextValue = None
-        result
-      }
-    }
+    // Delegate to the getValuesInTTLState(groupingKey) overload with the serialized
+    // implicit grouping key.
+    getValuesInTTLState(stateTypesEncoder.serializeGroupingKey())
   }
 }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
index 5eb48a86e342..1e6136fd38a3 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.time.Duration
 import java.util.UUID
 
-import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.{SparkIllegalArgumentException, 
SparkUnsupportedOperationException}
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl}
-import org.apache.spark.sql.streaming.{ListState, TimeMode, ValueState}
+import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
ListStateImplWithTTL, StatefulProcessorHandleImpl}
+import org.apache.spark.sql.streaming.{ListState, TimeMode, TTLConfig, 
ValueState}
 
 /**
  * Class that adds unit tests for ListState types used in arbitrary stateful
@@ -160,4 +161,87 @@ class ListStateSuite extends StateVariableSuiteBase {
       assert(listState1.get().toSeq === Seq.empty[Long])
     }
   }
+
+  // Writes a list with a 1 minute TTL, then opens the same state in a later batch whose
+  // timestamp equals the expiration time and checks that get()/exists() hide the values
+  // while the raw values and TTL index entries are still present in the store.
+  test(s"test List state TTL") {
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+      val timestampMs = 10
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs))
+
+      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+      val testState: ListStateImplWithTTL[String] = handle.getListState[String]("testState",
+        Encoders.STRING, ttlConfig).asInstanceOf[ListStateImplWithTTL[String]]
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+      testState.put(Array("v1", "v2", "v3"))
+      assert(testState.get().toSeq === Seq("v1", "v2", "v3"))
+      assert(testState.getWithoutEnforcingTTL().toSeq === Seq("v1", "v2", "v3"))
+
+      // Batch timestamp plus the 1 minute TTL, in milliseconds.
+      val ttlExpirationMs = timestampMs + 60000
+      var ttlValues = testState.getTTLValues()
+      assert(ttlValues.nonEmpty)
+      assert(ttlValues.forall(_._2 === ttlExpirationMs))
+      var ttlStateValueIterator = testState.getValuesInTTLState()
+      assert(ttlStateValueIterator.hasNext)
+
+      // Start a new batch whose processing timestamp equals the expiration time; the
+      // assertions below expect the values to count as expired at that timestamp.
+      val nextBatchHandle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs))
+
+      val nextBatchTestState: ListStateImplWithTTL[String] =
+        nextBatchHandle.getListState[String]("testState", Encoders.STRING, ttlConfig)
+          .asInstanceOf[ListStateImplWithTTL[String]]
+
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+
+      // ensure get does not return the expired value
+      assert(!nextBatchTestState.exists())
+      assert(nextBatchTestState.get().isEmpty)
+
+      // get() does not remove anything: the values and their single TTL index entry
+      // (one upsert for the whole put) are still in the store.
+      ttlValues = nextBatchTestState.getTTLValues()
+      assert(ttlValues.nonEmpty)
+      assert(ttlValues.forall(_._2 === ttlExpirationMs))
+      ttlStateValueIterator = nextBatchTestState.getValuesInTTLState()
+      assert(ttlStateValueIterator.hasNext)
+      assert(ttlStateValueIterator.next() === ttlExpirationMs)
+      assert(ttlStateValueIterator.isEmpty)
+
+      // getWithoutEnforcingTTL should still return the expired values
+      assert(nextBatchTestState.getWithoutEnforcingTTL().toSeq === Seq("v1", "v2", "v3"))
+
+      nextBatchTestState.clear()
+      assert(!nextBatchTestState.exists())
+      assert(nextBatchTestState.get().isEmpty)
+    }
+  }
+
+  // A null, zero or negative ttlDuration must be rejected by getListState with
+  // STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE.
+  test("test negative or zero TTL duration throws error") {
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+      val batchTimestampMs = 10
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
+        TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
+
+      Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
+        val ttlConfig = TTLConfig(ttlDuration)
+        val ex = intercept[SparkUnsupportedOperationException] {
+          handle.getListState[String]("testState", Encoders.STRING, ttlConfig)
+        }
+
+        checkError(
+          ex,
+          errorClass = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE",
+          parameters = Map(
+            "operationType" -> "update",
+            "stateName" -> "testState"
+          ),
+          matchPVals = true
+        )
+      }
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
index e9ffe4ca9269..aafbf4df60af 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
@@ -219,7 +219,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
     }
   }
 
-  test(s"ttl States are populated for timeMode=ProcessingTime") {
+  test("ttl States are populated for valueState and timeMode=ProcessingTime") {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
@@ -237,13 +237,32 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
     }
   }
 
-  test(s"ttl States are not populated for timeMode=None") {
+  // Only list states created with a TTLConfig should be registered in handle.ttlStates.
+  test("ttl States are populated for listState and timeMode=ProcessingTime") {
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store,
+        UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(),
+        batchTimestampMs = Some(10))
+
+      val listStateWithTTL = handle.getListState("testState",
+        Encoders.STRING, TTLConfig(Duration.ofHours(1)))
+
+      // create another state without TTL, this should not be captured in the handle
+      handle.getListState("testState", Encoders.STRING)
+
+      assert(handle.ttlStates.size() === 1)
+      assert(handle.ttlStates.get(0) === listStateWithTTL)
+    }
+  }
+
+  test("ttl States are not populated for timeMode=None") {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
         UUID.randomUUID(), keyExprEncoder, TimeMode.None())
 
-      handle.getValueState("testState", Encoders.STRING)
+      handle.getValueState("testValueState", Encoders.STRING)
+      handle.getListState("testListState", Encoders.STRING)
 
       assert(handle.ttlStates.isEmpty)
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index d2747e2976f4..e5875da947a3 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -324,7 +324,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       val ttlExpirationMs = timestampMs + 60000
       var ttlValue = testState.getTTLValue()
       assert(ttlValue.isDefined)
-      assert(ttlValue.get === ttlExpirationMs)
+      assert(ttlValue.get._2 === ttlExpirationMs)
       var ttlStateValueIterator = testState.getValuesInTTLState()
       assert(ttlStateValueIterator.hasNext)
 
@@ -346,7 +346,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
       // ttl value should still exist in state
       ttlValue = nextBatchTestState.getTTLValue()
       assert(ttlValue.isDefined)
-      assert(ttlValue.get === ttlExpirationMs)
+      assert(ttlValue.get._2 === ttlExpirationMs)
       ttlStateValueIterator = nextBatchTestState.getValuesInTTLState()
       assert(ttlStateValueIterator.hasNext)
       assert(ttlStateValueIterator.next() === ttlExpirationMs)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala
new file mode 100644
index 000000000000..299a3346b2e5
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{ListStateImplWithTTL, 
MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+/**
+ * Test processor backed by a TTL-enabled list state variable.
+ *
+ * Each input row carries an action name; the processor applies it to the list state and
+ * emits events describing either the user-visible values or the TTL bookkeeping, so that
+ * TTL test suites can inspect both.
+ *
+ * @param ttlConfig TTL configuration used when registering the list state variable.
+ */
+class ListStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent] {
+
+  @transient private var _listState: ListStateImplWithTTL[Int] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    // Cast to the TTL implementation so its inspection methods (getTTLValues,
+    // getValuesInTTLState, getWithoutEnforcingTTL) are reachable from processRow.
+    _listState = getHandle
+      .getListState("listState", Encoders.scalaInt, ttlConfig)
+      .asInstanceOf[ListStateImplWithTTL[Int]]
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[InputEvent],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+    // `toList` forces every row to be processed before this method returns, so all state
+    // reads and writes happen inside this call rather than when the output is consumed.
+    inputRows.flatMap(row => processRow(row, _listState)).toList.iterator
+  }
+
+  /**
+   * Applies a single test action to `listState`.
+   *
+   * Supported actions are "get", "get_without_enforcing_ttl", "get_ttl_value_from_state",
+   * "put", "append" and "get_values_in_ttl_state"; any other action is ignored.
+   *
+   * @param row the input event naming the action and carrying its value.
+   * @param listState the TTL-enabled list state to operate on.
+   * @return the events produced by the action; empty for "put", "append" and unknown actions.
+   */
+  def processRow(
+      row: InputEvent,
+      listState: ListStateImplWithTTL[Int]): Iterator[OutputEvent] = {
+    val key = row.key
+    val results: List[OutputEvent] = row.action match {
+      case "get" =>
+        listState.get().map(v => OutputEvent(key, v, isTTLValue = false, -1)).toList
+      case "get_without_enforcing_ttl" =>
+        listState.getWithoutEnforcingTTL()
+          .map(v => OutputEvent(key, v, isTTLValue = false, -1))
+          .toList
+      case "get_ttl_value_from_state" =>
+        // Each entry pairs a stored value with its expiration timestamp.
+        listState.getTTLValues().map { case (value, expirationMs) =>
+          OutputEvent(key, value, isTTLValue = true, expirationMs)
+        }.toList
+      case "put" =>
+        listState.put(Array(row.value))
+        Nil
+      case "append" =>
+        listState.appendValue(row.value)
+        Nil
+      case "get_values_in_ttl_state" =>
+        listState.getValuesInTTLState()
+          .map(expirationMs => OutputEvent(key, -1, isTTLValue = true, ttlValue = expirationMs))
+          .toList
+      case _ =>
+        Nil
+    }
+    results.iterator
+  }
+}
+
+/**
+ * Test suite for list state with TTL.
+ *
+ * Runs the shared TTL tests inherited from TransformWithStateTTLTest with a
+ * ListStateTTLProcessor, and adds list-specific tests verifying that expired elements are
+ * skipped whether they sit at the beginning, in the middle, or at the end of the list.
+ */
+class TransformWithListStateTTLSuite extends TransformWithStateTTLTest {
+
+  import testImplicits._
+
+  override def getProcessor(ttlConfig: TTLConfig):
+    StatefulProcessor[String, InputEvent, OutputEvent] = {
+    new ListStateTTLProcessor(ttlConfig)
+  }
+
+  override def getStateTTLMetricName: String = "numListStateWithTTLVars"
+
+  test("verify iterator works with expired values in beginning of list") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+
+      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+      val inputStream = MemoryStream[InputEvent]
+      val result = inputStream.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(
+          getProcessor(ttlConfig),
+          TimeMode.ProcessingTime(),
+          OutputMode.Append())
+      val clock = new StreamManualClock
+      testStream(result)(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputStream, InputEvent("k1", "put", 1)),
+        AdvanceManualClock(1 * 1000),
+        AddData(inputStream,
+          InputEvent("k1", "append", 2),
+          InputEvent("k1", "append", 3)
+        ),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // get ttl values
+        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(
+          OutputEvent("k1", 1, isTTLValue = true, 61000),
+          OutputEvent("k1", 2, isTTLValue = true, 62000),
+          OutputEvent("k1", 3, isTTLValue = true, 62000)
+        ),
+        // advance clock to add elements with a later TTL; the clock is now 48000 and the
+        // append batch below runs at 49000, hence expiration 109000
+        AdvanceManualClock(45 * 1000),
+        AddData(inputStream,
+          InputEvent("k1", "append", 4),
+          InputEvent("k1", "append", 5),
+          InputEvent("k1", "append", 6)
+        ),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(),
+        // get ttl values
+        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(
+          OutputEvent("k1", 1, isTTLValue = true, 61000),
+          OutputEvent("k1", 2, isTTLValue = true, 62000),
+          OutputEvent("k1", 3, isTTLValue = true, 62000),
+          OutputEvent("k1", 4, isTTLValue = true, 109000),
+          OutputEvent("k1", 5, isTTLValue = true, 109000),
+          OutputEvent("k1", 6, isTTLValue = true, 109000)
+        ),
+        // advance clock to expire the first three elements; the clock is now 65000,
+        // past their expirations of 61000 and 62000
+        AdvanceManualClock(15 * 1000),
+        AddData(inputStream, InputEvent("k1", "get", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(
+          OutputEvent("k1", 4, isTTLValue = false, -1),
+          OutputEvent("k1", 5, isTTLValue = false, -1),
+          OutputEvent("k1", 6, isTTLValue = false, -1)
+        ),
+        // ensure that expired elements are no longer in state
+        AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(
+          OutputEvent("k1", 4, isTTLValue = false, -1),
+          OutputEvent("k1", 5, isTTLValue = false, -1),
+          OutputEvent("k1", 6, isTTLValue = false, -1)
+        ),
+        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(
+          OutputEvent("k1", -1, isTTLValue = true, 109000)
+        )
+      )
+    }
+  }
+
+  // The TTL duration of a state variable can only change when the query is restarted, so
+  // elements can only be stored out of ascending TTL order across a restart.
+  // The following tests build such a list by stopping the query, restarting it from the
+  // same checkpoint with a different TTL config, and appending more elements. They then
+  // check that expired elements in the middle or at the end of the list are not returned.
+  test("verify iterator works with expired values in middle of list") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      withTempDir { checkpointLocation =>
+        // starting the query with a TTL of 3 minutes
+        val ttlConfig1 = TTLConfig(ttlDuration = Duration.ofMinutes(3))
+        val inputStream = MemoryStream[InputEvent]
+        val result1 = inputStream.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(
+            getProcessor(ttlConfig1),
+            TimeMode.ProcessingTime(),
+            OutputMode.Append())
+
+        val clock = new StreamManualClock
+        // add 3 elements with a duration of 3 minutes
+        // batch timestamp at the end of this block will be 4000
+        testStream(result1)(
+          StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
+            checkpointLocation = checkpointLocation.getAbsolutePath),
+          AddData(inputStream, InputEvent("k1", "put", 1)),
+          AdvanceManualClock(1 * 1000),
+          AddData(inputStream, InputEvent("k1", "append", 2)),
+          AddData(inputStream, InputEvent("k1", "append", 3)),
+          // advance clock to trigger processing
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(),
+          // get ttl values
+          AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = true, 181000),
+            OutputEvent("k1", 2, isTTLValue = true, 182000),
+            OutputEvent("k1", 3, isTTLValue = true, 182000)
+          ),
+          AddData(inputStream, InputEvent("k1", "get", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = false, -1),
+            OutputEvent("k1", 2, isTTLValue = false, -1),
+            OutputEvent("k1", 3, isTTLValue = false, -1)
+          ),
+          StopStream
+        )
+
+        // Here, we are restarting the query with a new TTL of 15 seconds
+        // so that we can add elements to the middle of the list that will
+        // expire quickly
+        // batch timestamp at the end of this block will be 7000
+        val ttlConfig2 = TTLConfig(ttlDuration = Duration.ofSeconds(15))
+        val result2 = inputStream.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(
+            getProcessor(ttlConfig2),
+            TimeMode.ProcessingTime(),
+            OutputMode.Append())
+        // add 3 elements with a duration of 15 seconds
+        testStream(result2)(
+          StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
+            checkpointLocation = checkpointLocation.getAbsolutePath),
+          AddData(inputStream, InputEvent("k1", "append", 4)),
+          AddData(inputStream, InputEvent("k1", "append", 5)),
+          AddData(inputStream, InputEvent("k1", "append", 6)),
+          // advance clock to trigger processing
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(),
+          // get all elements without enforcing ttl
+          AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = false, -1),
+            OutputEvent("k1", 2, isTTLValue = false, -1),
+            OutputEvent("k1", 3, isTTLValue = false, -1),
+            OutputEvent("k1", 4, isTTLValue = false, -1),
+            OutputEvent("k1", 5, isTTLValue = false, -1),
+            OutputEvent("k1", 6, isTTLValue = false, -1)
+          ),
+          AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = true, 181000),
+            OutputEvent("k1", 2, isTTLValue = true, 182000),
+            OutputEvent("k1", 3, isTTLValue = true, 182000),
+            OutputEvent("k1", 4, isTTLValue = true, 20000),
+            OutputEvent("k1", 5, isTTLValue = true, 20000),
+            OutputEvent("k1", 6, isTTLValue = true, 20000)
+          ),
+          StopStream
+        )
+
+        // Restart the stream with the first TTL config to add elements to the end
+        // with a TTL of 3 minutes
+        testStream(result1)(
+          StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
+            checkpointLocation = checkpointLocation.getAbsolutePath),
+          AddData(inputStream, InputEvent("k1", "append", 7)),
+          AddData(inputStream, InputEvent("k1", "append", 8)),
+          AddData(inputStream, InputEvent("k1", "append", 9)),
+          // advance clock to trigger processing
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(),
+          // check the expiration timestamps tracked in the TTL state before any expiry
+          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", -1, isTTLValue = true, 20000),
+            OutputEvent("k1", -1, isTTLValue = true, 181000),
+            OutputEvent("k1", -1, isTTLValue = true, 182000),
+            OutputEvent("k1", -1, isTTLValue = true, 188000)
+          ),
+          // progress batch timestamp from 9000 to 54000, expiring the middle
+          // three elements.
+          AdvanceManualClock(45 * 1000),
+          // Get all elements in the list
+          AddData(inputStream, InputEvent("k1", "get", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = false, -1),
+            OutputEvent("k1", 2, isTTLValue = false, -1),
+            OutputEvent("k1", 3, isTTLValue = false, -1),
+            OutputEvent("k1", 7, isTTLValue = false, -1),
+            OutputEvent("k1", 8, isTTLValue = false, -1),
+            OutputEvent("k1", 9, isTTLValue = false, -1)
+          ),
+          AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = false, -1),
+            OutputEvent("k1", 2, isTTLValue = false, -1),
+            OutputEvent("k1", 3, isTTLValue = false, -1),
+            OutputEvent("k1", 7, isTTLValue = false, -1),
+            OutputEvent("k1", 8, isTTLValue = false, -1),
+            OutputEvent("k1", 9, isTTLValue = false, -1)
+          ),
+          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", -1, isTTLValue = true, 181000),
+            OutputEvent("k1", -1, isTTLValue = true, 182000),
+            OutputEvent("k1", -1, isTTLValue = true, 188000)
+          ),
+          StopStream
+        )
+      }
+    }
+  }
+
+  test("verify iterator works with expired values in end of list") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      withTempDir { checkpointLocation =>
+        // first TTL config to start the query with a TTL of 2 minutes
+        val inputStream = MemoryStream[InputEvent]
+        val ttlConfig1 = TTLConfig(ttlDuration = Duration.ofMinutes(2))
+        val result1 = inputStream.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(
+            getProcessor(ttlConfig1),
+            TimeMode.ProcessingTime(),
+            OutputMode.Append())
+
+        // second TTL config we will use to start the query with a TTL of 1 minute
+        val ttlConfig2 = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+        val result2 = inputStream.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(
+            getProcessor(ttlConfig2),
+            TimeMode.ProcessingTime(),
+            OutputMode.Append())
+
+        val clock = new StreamManualClock
+        // add 3 elements with a TTL of 2 minutes
+        // expected batch timestamp at the end of the stream is 4000
+        testStream(result1)(
+          StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
+            checkpointLocation = checkpointLocation.getAbsolutePath),
+          AddData(inputStream, InputEvent("k1", "put", 1)),
+          AdvanceManualClock(1 * 1000),
+          AddData(inputStream, InputEvent("k1", "append", 2)),
+          AddData(inputStream, InputEvent("k1", "append", 3)),
+          // advance clock to trigger processing
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(),
+          // get ttl values
+          AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = true, 121000),
+            OutputEvent("k1", 2, isTTLValue = true, 122000),
+            OutputEvent("k1", 3, isTTLValue = true, 122000)
+          ),
+          AddData(inputStream, InputEvent("k1", "get", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = false, -1),
+            OutputEvent("k1", 2, isTTLValue = false, -1),
+            OutputEvent("k1", 3, isTTLValue = false, -1)
+          ),
+          StopStream
+        )
+
+        // Restart the query with a shorter TTL of 1 minute so that the elements appended
+        // at the end of the list expire before the ones at the beginning.
+        // The batch timestamp is 7000 before the clock is advanced to expire them.
+        testStream(result2)(
+          StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
+            checkpointLocation = checkpointLocation.getAbsolutePath),
+          AddData(inputStream, InputEvent("k1", "append", 4)),
+          AddData(inputStream, InputEvent("k1", "append", 5)),
+          AddData(inputStream, InputEvent("k1", "append", 6)),
+          // advance clock to trigger processing
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(),
+          // get ttl values
+          AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = true, 121000),
+            OutputEvent("k1", 2, isTTLValue = true, 122000),
+            OutputEvent("k1", 3, isTTLValue = true, 122000),
+            OutputEvent("k1", 4, isTTLValue = true, 65000),
+            OutputEvent("k1", 5, isTTLValue = true, 65000),
+            OutputEvent("k1", 6, isTTLValue = true, 65000)
+          ),
+          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", -1, isTTLValue = true, 121000),
+            OutputEvent("k1", -1, isTTLValue = true, 122000),
+            OutputEvent("k1", -1, isTTLValue = true, 65000)
+          ),
+          // expire end values, batch timestamp from 7000 to 67000
+          AdvanceManualClock(60 * 1000),
+          AddData(inputStream, InputEvent("k1", "get", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = false, -1),
+            OutputEvent("k1", 2, isTTLValue = false, -1),
+            OutputEvent("k1", 3, isTTLValue = false, -1)
+          ),
+          AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", 1, isTTLValue = false, -1),
+            OutputEvent("k1", 2, isTTLValue = false, -1),
+            OutputEvent("k1", 3, isTTLValue = false, -1)
+          ),
+          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1, null)),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(
+            OutputEvent("k1", -1, isTTLValue = true, 121000),
+            OutputEvent("k1", -1, isTTLValue = true, 122000)
+          ),
+          StopStream
+        )
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala
similarity index 60%
copy from sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
copy to sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala
index ac596869bbb0..2ddf69aa49e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala
@@ -20,9 +20,7 @@ package org.apache.spark.sql.streaming
 import java.sql.Timestamp
 import java.time.Duration
 
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.execution.streaming.{MemoryStream, ValueStateImpl, 
ValueStateImplWithTTL}
+import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -39,146 +37,17 @@ case class OutputEvent(
     isTTLValue: Boolean,
     ttlValue: Long)
 
-object TTLInputProcessFunction {
-  def processRow(
-      row: InputEvent,
-      valueState: ValueStateImplWithTTL[Int]): Iterator[OutputEvent] = {
-    var results = List[OutputEvent]()
-    val key = row.key
-    if (row.action == "get") {
-      val currState = valueState.getOption()
-      if (currState.isDefined) {
-        results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: 
results
-      }
-    } else if (row.action == "get_without_enforcing_ttl") {
-      val currState = valueState.getWithoutEnforcingTTL()
-      if (currState.isDefined) {
-        results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: 
results
-      }
-    } else if (row.action == "get_ttl_value_from_state") {
-      val ttlExpiration = valueState.getTTLValue()
-      if (ttlExpiration.isDefined) {
-        results = OutputEvent(key, -1, isTTLValue = true, ttlExpiration.get) 
:: results
-      }
-    } else if (row.action == "put") {
-      valueState.update(row.value)
-    } else if (row.action == "get_values_in_ttl_state") {
-      val ttlValues = valueState.getValuesInTTLState()
-      ttlValues.foreach { v =>
-        results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: 
results
-      }
-    }
-
-    results.iterator
-  }
-
-  def processNonTTLStateRow(
-      row: InputEvent,
-      valueState: ValueStateImpl[Int]): Iterator[OutputEvent] = {
-    var results = List[OutputEvent]()
-    val key = row.key
-    if (row.action == "get") {
-      val currState = valueState.getOption()
-      if (currState.isDefined) {
-        results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: 
results
-      }
-    } else if (row.action == "put") {
-      valueState.update(row.value)
-    }
-
-    results.iterator
-  }
-}
-
-class ValueStateTTLProcessor(ttlConfig: TTLConfig)
-  extends StatefulProcessor[String, InputEvent, OutputEvent]
-  with Logging {
-
-  @transient private var _valueState: ValueStateImplWithTTL[Int] = _
-
-  override def init(
-      outputMode: OutputMode,
-      timeMode: TimeMode): Unit = {
-    _valueState = getHandle
-      .getValueState("valueState", Encoders.scalaInt, ttlConfig)
-      .asInstanceOf[ValueStateImplWithTTL[Int]]
-  }
-
-  override def handleInputRows(
-      key: String,
-      inputRows: Iterator[InputEvent],
-      timerValues: TimerValues,
-      expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
-    var results = List[OutputEvent]()
-
-    inputRows.foreach { row =>
-      val resultIter = TTLInputProcessFunction.processRow(row, _valueState)
-      resultIter.foreach { r =>
-        results = r :: results
-      }
-    }
-
-    results.iterator
-  }
-}
-
-case class MultipleValueStatesTTLProcessor(
-    ttlKey: String,
-    noTtlKey: String,
-    ttlConfig: TTLConfig)
-  extends StatefulProcessor[String, InputEvent, OutputEvent]
-    with Logging {
-
-  @transient private var _valueStateWithTTL: ValueStateImplWithTTL[Int] = _
-  @transient private var _valueStateWithoutTTL: ValueStateImpl[Int] = _
-
-  override def init(
-      outputMode: OutputMode,
-      timeMode: TimeMode): Unit = {
-    _valueStateWithTTL = getHandle
-      .getValueState("valueState", Encoders.scalaInt, ttlConfig)
-      .asInstanceOf[ValueStateImplWithTTL[Int]]
-    _valueStateWithoutTTL = getHandle
-      .getValueState("valueState", Encoders.scalaInt)
-      .asInstanceOf[ValueStateImpl[Int]]
-  }
-
-  override def handleInputRows(
-      key: String,
-      inputRows: Iterator[InputEvent],
-      timerValues: TimerValues,
-      expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
-    var results = List[OutputEvent]()
-
-    if (key == ttlKey) {
-      inputRows.foreach { row =>
-        val resultIterator = TTLInputProcessFunction.processRow(row, 
_valueStateWithTTL)
-        resultIterator.foreach { r =>
-          results = r :: results
-        }
-      }
-    } else {
-      inputRows.foreach { row =>
-        val resultIterator = TTLInputProcessFunction.processNonTTLStateRow(row,
-          _valueStateWithoutTTL)
-        resultIterator.foreach { r =>
-          results = r :: results
-        }
-      }
-    }
-
-    results.iterator
-  }
-}
-
 /**
- * Tests that ttl works as expected for Value State for
- * processing time and event time based ttl.
+ * Test suite base for TransformWithState with TTL support.
  */
-class TransformWithValueStateTTLSuite
+abstract class TransformWithStateTTLTest
   extends StreamTest {
   import testImplicits._
 
+  def getProcessor(ttlConfig: TTLConfig): StatefulProcessor[String, 
InputEvent, OutputEvent]
+
+  def getStateTTLMetricName: String
+
   test("validate state is evicted at ttl expiry") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName) {
@@ -188,7 +57,7 @@ class TransformWithValueStateTTLSuite
         val result = inputStream.toDS()
           .groupByKey(x => x.key)
           .transformWithState(
-            new ValueStateTTLProcessor(ttlConfig),
+            getProcessor(ttlConfig),
             TimeMode.ProcessingTime(),
             OutputMode.Append())
 
@@ -219,7 +88,7 @@ class TransformWithValueStateTTLSuite
           // ensure ttl values were added correctly
           AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", 
-1)),
           AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+          CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = true, 61000)),
           AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", 
-1)),
           AdvanceManualClock(1 * 1000),
           CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
@@ -246,7 +115,7 @@ class TransformWithValueStateTTLSuite
             // for stateful operator
             val progData = q.recentProgress.filter(prog => 
prog.stateOperators.size > 0)
             assert(progData.filter(prog =>
-              
prog.stateOperators(0).customMetrics.get("numValueStateWithTTLVars") > 0).size 
> 0)
+              prog.stateOperators(0).customMetrics.get(getStateTTLMetricName) 
> 0).size > 0)
             assert(progData.filter(prog =>
               prog.stateOperators(0).customMetrics
                 .get("numValuesRemovedDueToTTLExpiry") > 0).size > 0)
@@ -264,7 +133,7 @@ class TransformWithValueStateTTLSuite
       val result = inputStream.toDS()
         .groupByKey(x => x.key)
         .transformWithState(
-          new ValueStateTTLProcessor(ttlConfig),
+          getProcessor(ttlConfig),
           TimeMode.ProcessingTime(),
           OutputMode.Append())
 
@@ -282,7 +151,7 @@ class TransformWithValueStateTTLSuite
         // ensure ttl values were added correctly
         AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
         AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = true, 61000)),
         AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
         AdvanceManualClock(1 * 1000),
         CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
@@ -297,7 +166,7 @@ class TransformWithValueStateTTLSuite
         // validate ttl value is updated in the state
         AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
         AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = true, 95000)),
         // validate ttl state has both ttl values present
         AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
         AdvanceManualClock(1 * 1000),
@@ -326,7 +195,7 @@ class TransformWithValueStateTTLSuite
       val result = inputStream.toDS()
         .groupByKey(x => x.key)
         .transformWithState(
-          new ValueStateTTLProcessor(ttlConfig),
+          getProcessor(ttlConfig),
           TimeMode.ProcessingTime(),
           OutputMode.Append())
 
@@ -346,7 +215,7 @@ class TransformWithValueStateTTLSuite
         // ensure ttl values were added correctly
         AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
         AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
+        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = true, 61000)),
         AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
         AdvanceManualClock(1 * 1000),
         CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
@@ -369,61 +238,6 @@ class TransformWithValueStateTTLSuite
     }
   }
 
-  test("validate multiple value states") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName) {
-      val ttlKey = "k1"
-      val noTtlKey = "k2"
-      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-      val inputStream = MemoryStream[InputEvent]
-      val result = inputStream.toDS()
-        .groupByKey(x => x.key)
-        .transformWithState(
-          MultipleValueStatesTTLProcessor(ttlKey, noTtlKey, ttlConfig),
-          TimeMode.ProcessingTime(),
-          OutputMode.Append())
-
-      val clock = new StreamManualClock
-      testStream(result)(
-        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
-        AddData(inputStream, InputEvent(ttlKey, "put", 1)),
-        AddData(inputStream, InputEvent(noTtlKey, "put", 2)),
-        // advance clock to trigger processing
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(),
-        // get both state values, and make sure we get unexpired value
-        AddData(inputStream, InputEvent(ttlKey, "get", -1)),
-        AddData(inputStream, InputEvent(noTtlKey, "get", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(
-          OutputEvent(ttlKey, 1, isTTLValue = false, -1),
-          OutputEvent(noTtlKey, 2, isTTLValue = false, -1)
-        ),
-        // ensure ttl values were added correctly, and noTtlKey has no ttl 
values
-        AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", 
-1)),
-        AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", 
-1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)),
-        AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", 
-1)),
-        AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", 
-1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)),
-        // advance clock after expiry
-        AdvanceManualClock(60 * 1000),
-        AddData(inputStream, InputEvent(ttlKey, "get", -1)),
-        AddData(inputStream, InputEvent(noTtlKey, "get", -1)),
-        // advance clock to trigger processing
-        AdvanceManualClock(1 * 1000),
-        // validate ttlKey is expired, bot noTtlKey is still present
-        CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)),
-        // validate ttl value is removed in the value state column family
-        AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", 
-1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer()
-      )
-    }
-  }
-
   test("validate only expired keys are removed from the state") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
@@ -433,7 +247,7 @@ class TransformWithValueStateTTLSuite
       val result = inputStream.toDS()
         .groupByKey(x => x.key)
         .transformWithState(
-          new ValueStateTTLProcessor(ttlConfig),
+          getProcessor(ttlConfig),
           TimeMode.ProcessingTime(),
           OutputMode.Append())
 
@@ -466,7 +280,7 @@ class TransformWithValueStateTTLSuite
         AddData(inputStream, InputEvent("k2", "get_values_in_ttl_state", -1)),
         AdvanceManualClock(1 * 1000),
         CheckNewAnswer(
-          OutputEvent("k2", -1, isTTLValue = true, 92000),
+          OutputEvent("k2", 2, isTTLValue = true, 92000),
           OutputEvent("k2", -1, isTTLValue = true, 92000))
       )
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index ac596869bbb0..54004b419f75 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.streaming
 
-import java.sql.Timestamp
 import java.time.Duration
 
 import org.apache.spark.internal.Logging
@@ -27,18 +26,6 @@ import 
org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 
-case class InputEvent(
-    key: String,
-    action: String,
-    value: Int,
-    eventTime: Timestamp = null)
-
-case class OutputEvent(
-    key: String,
-    value: Int,
-    isTTLValue: Boolean,
-    ttlValue: Long)
-
 object TTLInputProcessFunction {
   def processRow(
       row: InputEvent,
@@ -56,9 +43,11 @@ object TTLInputProcessFunction {
         results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: 
results
       }
     } else if (row.action == "get_ttl_value_from_state") {
-      val ttlExpiration = valueState.getTTLValue()
-      if (ttlExpiration.isDefined) {
-        results = OutputEvent(key, -1, isTTLValue = true, ttlExpiration.get) 
:: results
+      val ttlValue = valueState.getTTLValue()
+      if (ttlValue.isDefined) {
+        val value = ttlValue.get._1
+        val ttlExpiration = ttlValue.get._2
+        results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: 
results
       }
     } else if (row.action == "put") {
       valueState.update(row.value)
@@ -171,203 +160,16 @@ case class MultipleValueStatesTTLProcessor(
   }
 }
 
-/**
- * Tests that ttl works as expected for Value State for
- * processing time and event time based ttl.
- */
-class TransformWithValueStateTTLSuite
-  extends StreamTest {
-  import testImplicits._
+class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
 
-  test("validate state is evicted at ttl expiry") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName) {
-      withTempDir { dir =>
-        val inputStream = MemoryStream[InputEvent]
-        val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-        val result = inputStream.toDS()
-          .groupByKey(x => x.key)
-          .transformWithState(
-            new ValueStateTTLProcessor(ttlConfig),
-            TimeMode.ProcessingTime(),
-            OutputMode.Append())
+  import testImplicits._
 
-        val clock = new StreamManualClock
-        testStream(result)(
-          StartStream(
-            Trigger.ProcessingTime("1 second"),
-            triggerClock = clock,
-            checkpointLocation = dir.getAbsolutePath),
-          AddData(inputStream, InputEvent("k1", "put", 1)),
-          // advance clock to trigger processing
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(),
-          StopStream,
-          StartStream(
-            Trigger.ProcessingTime("1 second"),
-            triggerClock = clock,
-            checkpointLocation = dir.getAbsolutePath),
-          // get this state, and make sure we get unexpired value
-          AddData(inputStream, InputEvent("k1", "get", -1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
-          StopStream,
-          StartStream(
-            Trigger.ProcessingTime("1 second"),
-            triggerClock = clock,
-            checkpointLocation = dir.getAbsolutePath),
-          // ensure ttl values were added correctly
-          AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", 
-1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", 
-1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-          StopStream,
-          StartStream(
-            Trigger.ProcessingTime("1 second"),
-            triggerClock = clock,
-            checkpointLocation = dir.getAbsolutePath),
-          // advance clock so that state expires
-          AdvanceManualClock(60 * 1000),
-          AddData(inputStream, InputEvent("k1", "get", -1, null)),
-          AdvanceManualClock(1 * 1000),
-          // validate expired value is not returned
-          CheckNewAnswer(),
-          // ensure this state does not exist any longer in State
-          AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", 
-1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(),
-          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", 
-1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(),
-          Execute { q =>
-            // Filter for idle progress events and then verify the custom 
metrics
-            // for stateful operator
-            val progData = q.recentProgress.filter(prog => 
prog.stateOperators.size > 0)
-            assert(progData.filter(prog =>
-              
prog.stateOperators(0).customMetrics.get("numValueStateWithTTLVars") > 0).size 
> 0)
-            assert(progData.filter(prog =>
-              prog.stateOperators(0).customMetrics
-                .get("numValuesRemovedDueToTTLExpiry") > 0).size > 0)
-          }
-        )
-      }
-    }
+  override def getProcessor(ttlConfig: TTLConfig):
+    StatefulProcessor[String, InputEvent, OutputEvent] = {
+      new ValueStateTTLProcessor(ttlConfig)
   }
 
-  test("validate state update updates the expiration timestamp") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName) {
-      val inputStream = MemoryStream[InputEvent]
-      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-      val result = inputStream.toDS()
-        .groupByKey(x => x.key)
-        .transformWithState(
-          new ValueStateTTLProcessor(ttlConfig),
-          TimeMode.ProcessingTime(),
-          OutputMode.Append())
-
-      val clock = new StreamManualClock
-      testStream(result)(
-        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
-        AddData(inputStream, InputEvent("k1", "put", 1)),
-        // advance clock to trigger processing
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(),
-        // get this state, and make sure we get unexpired value
-        AddData(inputStream, InputEvent("k1", "get", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
-        // ensure ttl values were added correctly
-        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-        // advance clock and update expiration time
-        AdvanceManualClock(30 * 1000),
-        AddData(inputStream, InputEvent("k1", "put", 1)),
-        AddData(inputStream, InputEvent("k1", "get", -1)),
-        // advance clock to trigger processing
-        AdvanceManualClock(1 * 1000),
-        // validate value is not expired
-        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
-        // validate ttl value is updated in the state
-        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)),
-        // validate ttl state has both ttl values present
-        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000),
-          OutputEvent("k1", -1, isTTLValue = true, 95000)
-        ),
-        // advance clock after older expiration value
-        AdvanceManualClock(30 * 1000),
-        // ensure unexpired value is still present in the state
-        AddData(inputStream, InputEvent("k1", "get", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
-        // validate that the older expiration value is removed from ttl state
-        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000))
-      )
-    }
-  }
-
-  test("validate state is evicted at ttl expiry for no data batch") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-    classOf[RocksDBStateStoreProvider].getName) {
-      val inputStream = MemoryStream[InputEvent]
-      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-      val result = inputStream.toDS()
-        .groupByKey(x => x.key)
-        .transformWithState(
-          new ValueStateTTLProcessor(ttlConfig),
-          TimeMode.ProcessingTime(),
-          OutputMode.Append())
-
-      val clock = new StreamManualClock
-      testStream(result)(
-        StartStream(
-          Trigger.ProcessingTime("1 second"),
-          triggerClock = clock),
-        AddData(inputStream, InputEvent("k1", "put", 1)),
-        // advance clock to trigger processing
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(),
-        // get this state, and make sure we get unexpired value
-        AddData(inputStream, InputEvent("k1", "get", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
-        // ensure ttl values were added correctly
-        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-        // advance clock so that state expires
-        AdvanceManualClock(60 * 1000),
-        // run a no data batch
-        CheckNewAnswer(),
-        AddData(inputStream, InputEvent("k1", "get", -1)),
-        AdvanceManualClock(1 * 1000),
-        // validate expired value is not returned
-        CheckNewAnswer(),
-        // ensure this state does not exist any longer in State
-        AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", 
-1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(),
-        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer()
-      )
-    }
-  }
+  override def getStateTTLMetricName: String = "numValueStateWithTTLVars"
 
   test("validate multiple value states") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
@@ -403,7 +205,7 @@ class TransformWithValueStateTTLSuite
         AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", 
-1)),
         AddData(inputStream, InputEvent(noTtlKey, "get_ttl_value_from_state", 
-1)),
         AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent(ttlKey, -1, isTTLValue = true, 61000)),
+        CheckNewAnswer(OutputEvent(ttlKey, 1, isTTLValue = true, 61000)),
         AddData(inputStream, InputEvent(ttlKey, "get_values_in_ttl_state", 
-1)),
         AddData(inputStream, InputEvent(noTtlKey, "get_values_in_ttl_state", 
-1)),
         AdvanceManualClock(1 * 1000),
@@ -423,52 +225,4 @@ class TransformWithValueStateTTLSuite
       )
     }
   }
-
-  test("validate only expired keys are removed from the state") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName,
-      SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-      val inputStream = MemoryStream[InputEvent]
-      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-      val result = inputStream.toDS()
-        .groupByKey(x => x.key)
-        .transformWithState(
-          new ValueStateTTLProcessor(ttlConfig),
-          TimeMode.ProcessingTime(),
-          OutputMode.Append())
-
-      val clock = new StreamManualClock
-      testStream(result)(
-        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
-        AddData(inputStream, InputEvent("k1", "put", 1)),
-        // advance clock to trigger processing
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(),
-        // advance clock halfway to expiration ttl, and add another key
-        AdvanceManualClock(30 * 1000),
-        AddData(inputStream, InputEvent("k2", "put", 2)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(),
-        // advance clock so that key k1 is expired
-        AdvanceManualClock(30 * 1000),
-        AddData(inputStream, InputEvent("k1", "get", 1)),
-        AddData(inputStream, InputEvent("k2", "get", -1)),
-        AdvanceManualClock(1 * 1000),
-        // validate k1 is expired and k2 is not
-        CheckNewAnswer(OutputEvent("k2", 2, isTTLValue = false, -1)),
-        // validate k1 is deleted from state
-        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
-        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(),
-        // validate k2 exists in state
-        AddData(inputStream, InputEvent("k2", "get_ttl_value_from_state", -1)),
-        AddData(inputStream, InputEvent("k2", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(
-          OutputEvent("k2", -1, isTTLValue = true, 92000),
-          OutputEvent("k2", -1, isTTLValue = true, 92000))
-      )
-    }
-  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to