This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new afbebfbadc4b [SPARK-47331][SS] Serialization using case classes/primitives/POJO based on SQL encoder for Arbitrary State API v2
afbebfbadc4b is described below

commit afbebfbadc4b5e927df7c568a8afb08fc4407f58
Author: jingz-db <jing.z...@databricks.com>
AuthorDate: Mon Mar 11 09:20:44 2024 +0900

    [SPARK-47331][SS] Serialization using case classes/primitives/POJO based on SQL encoder for Arbitrary State API v2
    
    ### What changes were proposed in this pull request?
    
    In the new operator for arbitrary state-v2, we cannot rely on the session/encoder being available, since the various state instances are initialized on the executors. Hence, for state serialization, we propose to have the user explicitly pass in an encoder for each state variable and to serialize primitives/case classes/POJOs with the SQL encoder. Leveraging the SQL encoder can speed up serialization.
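
    For intuition, the following minimal sketch (not part of this patch; the `UserState` case class is hypothetical) shows the round trip the SQL-encoder-based serializer performs, using the same `encoderFor`/`createSerializer`/`resolveAndBind().createDeserializer()` calls as `StateTypesEncoder` in the diff below:

        import org.apache.spark.sql.Encoders
        import org.apache.spark.sql.catalyst.encoders.encoderFor

        // Hypothetical state type; any case class with a product encoder works.
        case class UserState(id: Long, name: String)

        val expressionEnc = encoderFor(Encoders.product[UserState])
        val toRow = expressionEnc.createSerializer()    // object -> InternalRow (UnsafeRow)
        val fromRow = expressionEnc.resolveAndBind().createDeserializer()

        val row = toRow(UserState(1L, "a"))         // serialize for the state store
        assert(fromRow(row) == UserState(1L, "a"))  // deserialize back to the object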
    
    ### Why are the changes needed?
    
    These changes provide a dedicated serializer for state-v2. They are part of the work to add a new stateful streaming operator for arbitrary state management, which provides a number of new features listed in the SPIP JIRA: https://issues.apache.org/jira/browse/SPARK-45939
    
    ### Does this PR introduce _any_ user-facing change?
    
    Users will need to specify the SQL encoder for their state variable:
    `def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T]`
    `def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T]`
    
    For a primitive type, the encoder is e.g. `Encoders.scalaLong`; for a case class, `Encoders.product[CaseClass]`; for a POJO, `Encoders.bean(classOf[POJOClass])`.
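
    For illustration, a minimal sketch of a user-defined processor using the new signatures (the processor and the "count" state name are hypothetical, patterned after the test processors updated in this patch):

        import org.apache.spark.sql.Encoders
        import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimerValues, ValueState}

        class RunningCountProcessor extends StatefulProcessor[String, String, (String, Long)] {
          @transient private var _count: ValueState[Long] = _

          override def init(outputMode: OutputMode): Unit = {
            // The SQL encoder for the state type is now passed explicitly.
            _count = getHandle.getValueState[Long]("count", Encoders.scalaLong)
          }

          override def handleInputRows(
              key: String,
              inputRows: Iterator[String],
              timerValues: TimerValues): Iterator[(String, Long)] = {
            val count = (if (_count.exists()) _count.get() else 0L) + inputRows.size
            _count.update(count)
            Iterator((key, count))
          }
        }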
    
    ### How was this patch tested?
    
    Unit tests covering primitives, case classes, and POJOs separately in `ValueStateSuite`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45447 from jingz-db/sql-encoder-state-v2.
    
    Authored-by: jingz-db <jing.z...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/streaming/StatefulProcessorHandle.scala    |   7 +-
 .../sql/execution/streaming/ListStateImpl.scala    |   8 +-
 .../streaming/StateTypesEncoderUtils.scala         |  41 +++++---
 .../streaming/StatefulProcessorHandleImpl.scala    |   9 +-
 .../sql/execution/streaming/ValueStateImpl.scala   |   9 +-
 .../execution/streaming/state/POJOTestClass.java   |  78 ++++++++++++++
 .../streaming/state/ValueStateSuite.scala          | 117 ++++++++++++++++++++-
 .../streaming/TransformWithListStateSuite.scala    |   7 +-
 .../sql/streaming/TransformWithStateSuite.scala    |  11 +-
 9 files changed, 250 insertions(+), 37 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
index 5d3390f80f6d..86bf1e85f90c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.streaming
 import java.io.Serializable
 
 import org.apache.spark.annotation.{Evolving, Experimental}
+import org.apache.spark.sql.Encoder
 
 /**
  * Represents the operation handle provided to the stateful processor used in the
@@ -33,20 +34,22 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
   * The user must ensure to call this function only within the `init()` method of the
    * StatefulProcessor.
    * @param stateName - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
    * @tparam T - type of state variable
   * @return - instance of ValueState of type T that can be used to store state persistently
    */
-  def getValueState[T](stateName: String): ValueState[T]
+  def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T]
 
   /**
    * Creates new or returns existing list state associated with stateName.
    * The ListState persists values of type T.
    *
    * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
    * @tparam T - type of state variable
   * @return - instance of ListState of type T that can be used to store state persistently
    */
-  def getListState[T](stateName: String): ListState[T]
+  def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T]
 
   /** Function to return queryInfo for currently running task */
   def getQueryInfo(): QueryInfo
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
index b6ed48dab579..d0be62293d05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
 import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreErrors}
@@ -28,17 +29,20 @@ import org.apache.spark.sql.streaming.ListState
  *
  * @param store - reference to the StateStore instance to be used for storing state
  * @param stateName - name of logical state partition
+ * @param keyEnc - Spark SQL encoder for key
+ * @param valEncoder - Spark SQL encoder for value
  * @tparam S - data type of object that will be stored in the list
  */
 class ListStateImpl[S](
      store: StateStore,
      stateName: String,
-     keyExprEnc: ExpressionEncoder[Any])
+     keyExprEnc: ExpressionEncoder[Any],
+     valEncoder: Encoder[S])
   extends ListState[S] with Logging {
 
   private val keySerializer = keyExprEnc.createSerializer()
 
-  private val stateTypesEncoder = StateTypesEncoder(keySerializer, stateName)
+  private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName)
 
   store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, numColsPrefixKey = 0,
     VALUE_ROW_SCHEMA, useMultipleValuesPerKey = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
index 15d77030d57b..36758eafa392 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
@@ -17,10 +17,10 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import org.apache.commons.lang3.SerializationUtils
-
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
 import org.apache.spark.sql.types.{BinaryType, StructType}
@@ -41,17 +41,27 @@ object StateKeyValueRowSchema {
  *
  * @param keySerializer - serializer to serialize the grouping key of type `GK`
  *     to an [[InternalRow]]
+ * @param valEncoder - SQL encoder for value of type `S`
  * @param stateName - name of logical state partition
  * @tparam GK - grouping key type
+ * @tparam V - value type
  */
-class StateTypesEncoder[GK](
+class StateTypesEncoder[GK, V](
     keySerializer: Serializer[GK],
+    valEncoder: Encoder[V],
     stateName: String) {
   import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema._
 
+  /** Variables reused for conversions between byte array and UnsafeRow */
   private val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA)
   private val valueProjection = UnsafeProjection.create(VALUE_ROW_SCHEMA)
 
+  /** Variables reused for value conversions between Spark SQL and object */
+  private val valExpressionEnc = encoderFor(valEncoder)
+  private val objToRowSerializer = valExpressionEnc.createSerializer()
+  private val rowToObjDeserializer = valExpressionEnc.resolveAndBind().createDeserializer()
+  private val reuseRow = new UnsafeRow(valEncoder.schema.fields.length)
+
   // TODO: validate places that are trying to encode the key and check if we can eliminate/
   // add caching for some of these calls.
   def encodeGroupingKey(): UnsafeRow = {
@@ -66,23 +76,26 @@ class StateTypesEncoder[GK](
     keyRow
   }
 
-  def encodeValue[S](value: S): UnsafeRow = {
-    val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable])
-    val valueRow = valueProjection(InternalRow(valueByteArr))
-    valueRow
+  def encodeValue(value: V): UnsafeRow = {
+    val objRow: InternalRow = objToRowSerializer.apply(value)
+    val bytes = objRow.asInstanceOf[UnsafeRow].getBytes()
+    val valRow = valueProjection(InternalRow(bytes))
+    valRow
   }
 
-  def decodeValue[S](row: UnsafeRow): S = {
-    SerializationUtils
-      .deserialize(row.getBinary(0))
-      .asInstanceOf[S]
+  def decodeValue(row: UnsafeRow): V = {
+    val bytes = row.getBinary(0)
+    reuseRow.pointTo(bytes, bytes.length)
+    val value = rowToObjDeserializer.apply(reuseRow)
+    value
   }
 }
 
 object StateTypesEncoder {
-  def apply[GK](
+  def apply[GK, V](
       keySerializer: Serializer[GK],
-      stateName: String): StateTypesEncoder[GK] = {
-    new StateTypesEncoder[GK](keySerializer, stateName)
+      valEncoder: Encoder[V],
+      stateName: String): StateTypesEncoder[GK, V] = {
+    new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 56a325a31e33..d5dd9fcaf401 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -20,6 +20,7 @@ import java.util.UUID
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.{ListState, QueryInfo, StatefulProcessorHandle, ValueState}
@@ -112,10 +113,10 @@ class StatefulProcessorHandleImpl(
 
   def getHandleState: StatefulProcessorHandleState = currState
 
-  override def getValueState[T](stateName: String): ValueState[T] = {
+  override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = {
     verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " +
       "initialization is complete")
-    val resultState = new ValueStateImpl[T](store, stateName, keyEncoder)
+    val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder)
     resultState
   }
 
@@ -132,10 +133,10 @@ class StatefulProcessorHandleImpl(
     store.removeColFamilyIfExists(stateName)
   }
 
-  override def getListState[T](stateName: String): ListState[T] = {
+  override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
     verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " +
       "initialization is complete")
-    val resultState = new ListStateImpl[T](store, stateName, keyEncoder)
+    val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder)
     resultState
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
index a94a49d88325..5d2b9881c78d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
@@ -29,16 +30,18 @@ import org.apache.spark.sql.streaming.ValueState
  * @param store - reference to the StateStore instance to be used for storing state
  * @param stateName - name of logical state partition
  * @param keyEnc - Spark SQL encoder for key
+ * @param valEncoder - Spark SQL encoder for value
  * @tparam S - data type of object that will be stored
  */
 class ValueStateImpl[S](
     store: StateStore,
     stateName: String,
-    keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging {
+    keyExprEnc: ExpressionEncoder[Any],
+    valEncoder: Encoder[S]) extends ValueState[S] with Logging {
 
   private val keySerializer = keyExprEnc.createSerializer()
 
-  private val stateTypesEncoder = StateTypesEncoder(keySerializer, stateName)
+  private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName)
 
   store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, numColsPrefixKey = 0,
     VALUE_ROW_SCHEMA)
@@ -57,7 +60,7 @@ class ValueStateImpl[S](
   override def get(): S = {
     val retRow = getImpl()
     if (retRow != null) {
-      stateTypesEncoder.decodeValue[S](retRow)
+      stateTypesEncoder.decodeValue(retRow)
     } else {
       null.asInstanceOf[S]
     }
diff --git a/sql/core/src/test/java/org/apache/spark/sql/execution/streaming/state/POJOTestClass.java b/sql/core/src/test/java/org/apache/spark/sql/execution/streaming/state/POJOTestClass.java
new file mode 100644
index 000000000000..0ba75f7789d0
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/execution/streaming/state/POJOTestClass.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state;
+
+/**
+ * A POJO class used for tests of arbitrary state SQL encoder.
+ */
+public class POJOTestClass {
+  // Fields
+  private String name;
+  private int id;
+
+  // Constructors
+  public POJOTestClass() {
+    // Default constructor
+  }
+
+  public POJOTestClass(String name, int id) {
+    this.name = name;
+    this.id = id;
+  }
+
+  // Getter methods
+  public String getName() {
+    return name;
+  }
+
+  public int getId() {
+    return id;
+  }
+
+  // Setter methods
+  public void setName(String name) {
+    this.name = name;
+  }
+
+  public void setId(int id) {
+    this.id = id;
+  }
+
+  // Additional methods if needed
+  public void incrementId() {
+    id++;
+    System.out.println(name + " is now " + id + "!");
+  }
+
+  // Override toString for better representation
+  @Override
+  public String toString() {
+    return "POJOTestClass{" +
+      "name='" + name + '\'' +
+      ", age=" + id +
+      '}';
+  }
+
+  // Override equals for custom equality
+  @Override
+  public boolean equals(Object obj) {
+    POJOTestClass testObj = (POJOTestClass) obj;
+    return id == testObj.id && name.equals(testObj.name);
+  }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index be77f7a887c7..71462cb4b643 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -33,6 +33,9 @@ import org.apache.spark.sql.streaming.ValueState
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 
+/** A case class for SQL encoder test purpose */
+case class TestClass(var id: Long, var name: String)
+
 /**
 * Class that adds tests for single value ValueState types used in arbitrary stateful
  * operators such as transformWithState
@@ -93,7 +96,7 @@ class ValueStateSuite extends SharedSparkSession
         Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
 
       val stateName = "testState"
-      val testState: ValueState[Long] = handle.getValueState[Long]("testState")
+      val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong)
       assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
       val ex = intercept[Exception] {
         testState.update(123)
@@ -136,7 +139,7 @@ class ValueStateSuite extends SharedSparkSession
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
         Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
 
-      val testState: ValueState[Long] = handle.getValueState[Long]("testState")
+      val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState.update(123)
       assert(testState.get() === 123)
@@ -162,8 +165,10 @@ class ValueStateSuite extends SharedSparkSession
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
         Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
 
-      val testState1: ValueState[Long] = handle.getValueState[Long]("testState1")
-      val testState2: ValueState[Long] = handle.getValueState[Long]("testState2")
+      val testState1: ValueState[Long] = handle.getValueState[Long](
+        "testState1", Encoders.scalaLong)
+      val testState2: ValueState[Long] = handle.getValueState[Long](
+        "testState2", Encoders.scalaLong)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState1.update(123)
       assert(testState1.get() === 123)
@@ -217,4 +222,108 @@ class ValueStateSuite extends SharedSparkSession
       matchPVals = true
     )
   }
+
+  test("test SQL encoder - Value state operations for Primitive(Double) 
instances") {
+    tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+
+      val testState: ValueState[Double] = handle.getValueState[Double]("testState",
+        Encoders.scalaDouble)
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+      testState.update(1.0)
+      assert(testState.get().equals(1.0))
+      testState.clear()
+      assert(!testState.exists())
+      assert(testState.get() === null)
+
+      testState.update(2.0)
+      assert(testState.get().equals(2.0))
+      testState.update(3.0)
+      assert(testState.get().equals(3.0))
+
+      testState.clear()
+      assert(!testState.exists())
+      assert(testState.get() === null)
+    }
+  }
+
+  test("test SQL encoder - Value state operations for Primitive(Long) 
instances") {
+    tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+
+      val testState: ValueState[Long] = handle.getValueState[Long]("testState",
+        Encoders.scalaLong)
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+      testState.update(1L)
+      assert(testState.get().equals(1L))
+      testState.clear()
+      assert(!testState.exists())
+      assert(testState.get() === null)
+
+      testState.update(2L)
+      assert(testState.get().equals(2L))
+      testState.update(3L)
+      assert(testState.get().equals(3L))
+
+      testState.clear()
+      assert(!testState.exists())
+      assert(testState.get() === null)
+    }
+  }
+
+  test("test SQL encoder - Value state operations for case class instances") {
+    tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+
+      val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState",
+        Encoders.product[TestClass])
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+      testState.update(TestClass(1, "testcase1"))
+      assert(testState.get().equals(TestClass(1, "testcase1")))
+      testState.clear()
+      assert(!testState.exists())
+      assert(testState.get() === null)
+
+      testState.update(TestClass(2, "testcase2"))
+      assert(testState.get() === TestClass(2, "testcase2"))
+      testState.update(TestClass(3, "testcase3"))
+      assert(testState.get() === TestClass(3, "testcase3"))
+
+      testState.clear()
+      assert(!testState.exists())
+      assert(testState.get() === null)
+    }
+  }
+
+  test("test SQL encoder - Value state operations for POJO instances") {
+    tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+
+      val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState",
+        Encoders.bean(classOf[POJOTestClass]))
+      ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+      testState.update(new POJOTestClass("testcase1", 1))
+      assert(testState.get().equals(new POJOTestClass("testcase1", 1)))
+      testState.clear()
+      assert(!testState.exists())
+      assert(testState.get() === null)
+
+      testState.update(new POJOTestClass("testcase2", 2))
+      assert(testState.get().equals(new POJOTestClass("testcase2", 2)))
+      testState.update(new POJOTestClass("testcase3", 3))
+      assert(testState.get().equals(new POJOTestClass("testcase3", 3)))
+
+      testState.clear()
+      assert(!testState.exists())
+      assert(testState.get() === null)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
index 3d085da4ab58..9572f7006f37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.streaming
 
 import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
@@ -30,7 +31,7 @@ class TestListStateProcessor
   @transient var _listState: ListState[String] = _
 
   override def init(outputMode: OutputMode): Unit = {
-    _listState = getHandle.getListState("testListState")
+    _listState = getHandle.getListState("testListState", Encoders.STRING)
   }
 
   override def handleInputRows(
@@ -86,8 +87,8 @@ class ToggleSaveAndEmitProcessor
   @transient var _valueState: ValueState[Boolean] = _
 
   override def init(outputMode: OutputMode): Unit = {
-    _listState = getHandle.getListState("testListState")
-    _valueState = getHandle.getValueState("testValueState")
+    _listState = getHandle.getListState("testListState", Encoders.STRING)
+    _valueState = getHandle.getValueState("testValueState", Encoders.scalaBoolean)
   }
 
   override def handleInputRows(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 8a87472a023a..cc8c64c94c02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.streaming
 
 import org.apache.spark.{SparkException, SparkRuntimeException}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StateStoreMultipleColumnFamiliesNotSupportedException}
 import org.apache.spark.sql.internal.SQLConf
@@ -32,7 +33,7 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S
   @transient private var _countState: ValueState[Long] = _
 
   override def init(outputMode: OutputMode): Unit = {
-    _countState = getHandle.getValueState[Long]("countState")
+    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
   }
 
   override def handleInputRows(
@@ -59,8 +60,8 @@ class RunningCountMostRecentStatefulProcessor
   @transient private var _mostRecent: ValueState[String] = _
 
   override def init(outputMode: OutputMode): Unit = {
-    _countState = getHandle.getValueState[Long]("countState")
-    _mostRecent = getHandle.getValueState[String]("mostRecent")
+    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
+    _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING)
   }
   override def handleInputRows(
       key: String,
@@ -88,7 +89,7 @@ class MostRecentStatefulProcessorWithDeletion
 
   override def init(outputMode: OutputMode): Unit = {
     getHandle.deleteIfExists("countState")
-    _mostRecent = getHandle.getValueState[String]("mostRecent")
+    _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING)
   }
 
   override def handleInputRows(
@@ -116,7 +117,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess
       inputRows: Iterator[String],
       timerValues: TimerValues): Iterator[(String, String)] = {
     // Trying to create value state here should fail
-    _tempState = getHandle.getValueState[Long]("tempState")
+    _tempState = getHandle.getValueState[Long]("tempState", Encoders.scalaLong)
     Iterator.empty
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
