This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new afbebfbadc4b [SPARK-47331][SS] Serialization using case classes/primitives/POJO based on SQL encoder for Arbitrary State API v2 afbebfbadc4b is described below commit afbebfbadc4b5e927df7c568a8afb08fc4407f58 Author: jingz-db <jing.z...@databricks.com> AuthorDate: Mon Mar 11 09:20:44 2024 +0900 [SPARK-47331][SS] Serialization using case classes/primitives/POJO based on SQL encoder for Arbitrary State API v2 ### What changes were proposed in this pull request? In the new operator for arbitrary state-v2, we cannot rely on the session/encoder being available, because the various state instances are initialized on the executors. Hence, for state serialization, we propose letting users explicitly pass in an encoder for each state variable, and serializing primitives/case classes/POJOs with that SQL encoder. Leveraging the SQL encoder can also speed up serialization. ### Why are the changes needed? These changes are needed to provide a dedicated serializer for state-v2. They are part of the work on adding a new stateful streaming operator for arbitrary state management, which provides a number of new features listed in the SPIP JIRA here - https://issues.apache.org/jira/browse/SPARK-45939 ### Does this PR introduce _any_ user-facing change? Yes. Users will need to specify the SQL encoder for each of their state variables: `def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T]` `def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T]` For primitive types, the encoder is something like `Encoders.scalaLong`; for case classes, `Encoders.product[CaseClass]`; for POJOs, `Encoders.bean(classOf[POJOClass])`. ### How was this patch tested? Added separate unit tests for primitives, case classes, and POJOs in `ValueStateSuite`. ### Was this patch authored or co-authored using generative AI tooling? No Closes #45447 from jingz-db/sql-encoder-state-v2. 
Authored-by: jingz-db <jing.z...@databricks.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../sql/streaming/StatefulProcessorHandle.scala | 7 +- .../sql/execution/streaming/ListStateImpl.scala | 8 +- .../streaming/StateTypesEncoderUtils.scala | 41 +++++--- .../streaming/StatefulProcessorHandleImpl.scala | 9 +- .../sql/execution/streaming/ValueStateImpl.scala | 9 +- .../execution/streaming/state/POJOTestClass.java | 78 ++++++++++++++ .../streaming/state/ValueStateSuite.scala | 117 ++++++++++++++++++++- .../streaming/TransformWithListStateSuite.scala | 7 +- .../sql/streaming/TransformWithStateSuite.scala | 11 +- 9 files changed, 250 insertions(+), 37 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 5d3390f80f6d..86bf1e85f90c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.streaming import java.io.Serializable import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.sql.Encoder /** * Represents the operation handle provided to the stateful processor used in the @@ -33,20 +34,22 @@ private[sql] trait StatefulProcessorHandle extends Serializable { * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. * @param stateName - name of the state variable + * @param valEncoder - SQL encoder for state variable * @tparam T - type of state variable * @return - instance of ValueState of type T that can be used to store state persistently */ - def getValueState[T](stateName: String): ValueState[T] + def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] /** * Creates new or returns existing list state associated with stateName. 
* The ListState persists values of type T. * * @param stateName - name of the state variable + * @param valEncoder - SQL encoder for state variable * @tparam T - type of state variable * @return - instance of ListState of type T that can be used to store state persistently */ - def getListState[T](stateName: String): ListState[T] + def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] /** Function to return queryInfo for currently running task */ def getQueryInfo(): QueryInfo diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index b6ed48dab579..d0be62293d05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreErrors} @@ -28,17 +29,20 @@ import org.apache.spark.sql.streaming.ListState * * @param store - reference to the StateStore instance to be used for storing state * @param stateName - name of logical state partition + * @param keyEnc - Spark SQL encoder for key + * @param valEncoder - Spark SQL encoder for value * @tparam S - data type of object that will be stored in the list */ class ListStateImpl[S]( store: StateStore, stateName: String, - keyExprEnc: ExpressionEncoder[Any]) + keyExprEnc: ExpressionEncoder[Any], + valEncoder: Encoder[S]) extends ListState[S] with Logging { private val keySerializer = keyExprEnc.createSerializer() - private val stateTypesEncoder = 
StateTypesEncoder(keySerializer, stateName) + private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, numColsPrefixKey = 0, VALUE_ROW_SCHEMA, useMultipleValuesPerKey = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index 15d77030d57b..36758eafa392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.streaming -import org.apache.commons.lang3.SerializationUtils - +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.state.StateStoreErrors import org.apache.spark.sql.types.{BinaryType, StructType} @@ -41,17 +41,27 @@ object StateKeyValueRowSchema { * * @param keySerializer - serializer to serialize the grouping key of type `GK` * to an [[InternalRow]] + * @param valEncoder - SQL encoder for value of type `S` * @param stateName - name of logical state partition * @tparam GK - grouping key type + * @tparam V - value type */ -class StateTypesEncoder[GK]( +class StateTypesEncoder[GK, V]( keySerializer: Serializer[GK], + valEncoder: Encoder[V], stateName: String) { import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema._ + /** Variables reused for conversions between byte array and UnsafeRow */ private val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA) private val valueProjection = 
UnsafeProjection.create(VALUE_ROW_SCHEMA) + /** Variables reused for value conversions between spark sql and object */ + private val valExpressionEnc = encoderFor(valEncoder) + private val objToRowSerializer = valExpressionEnc.createSerializer() + private val rowToObjDeserializer = valExpressionEnc.resolveAndBind().createDeserializer() + private val reuseRow = new UnsafeRow(valEncoder.schema.fields.length) + // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. def encodeGroupingKey(): UnsafeRow = { @@ -66,23 +76,26 @@ class StateTypesEncoder[GK]( keyRow } - def encodeValue[S](value: S): UnsafeRow = { - val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable]) - val valueRow = valueProjection(InternalRow(valueByteArr)) - valueRow + def encodeValue(value: V): UnsafeRow = { + val objRow: InternalRow = objToRowSerializer.apply(value) + val bytes = objRow.asInstanceOf[UnsafeRow].getBytes() + val valRow = valueProjection(InternalRow(bytes)) + valRow } - def decodeValue[S](row: UnsafeRow): S = { - SerializationUtils - .deserialize(row.getBinary(0)) - .asInstanceOf[S] + def decodeValue(row: UnsafeRow): V = { + val bytes = row.getBinary(0) + reuseRow.pointTo(bytes, bytes.length) + val value = rowToObjDeserializer.apply(reuseRow) + value } } object StateTypesEncoder { - def apply[GK]( + def apply[GK, V]( keySerializer: Serializer[GK], - stateName: String): StateTypesEncoder[GK] = { - new StateTypesEncoder[GK](keySerializer, stateName) + valEncoder: Encoder[V], + stateName: String): StateTypesEncoder[GK, V] = { + new StateTypesEncoder[GK, V](keySerializer, valEncoder, stateName) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 56a325a31e33..d5dd9fcaf401 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -20,6 +20,7 @@ import java.util.UUID import org.apache.spark.TaskContext import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.{ListState, QueryInfo, StatefulProcessorHandle, ValueState} @@ -112,10 +113,10 @@ class StatefulProcessorHandleImpl( def getHandleState: StatefulProcessorHandleState = currState - override def getValueState[T](stateName: String): ValueState[T] = { + override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + "initialization is complete") - val resultState = new ValueStateImpl[T](store, stateName, keyEncoder) + val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) resultState } @@ -132,10 +133,10 @@ class StatefulProcessorHandleImpl( store.removeColFamilyIfExists(stateName) } - override def getListState[T](stateName: String): ListState[T] = { + override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + "initialization is complete") - val resultState = new ListStateImpl[T](store, stateName, keyEncoder) + val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) resultState } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index a94a49d88325..5d2b9881c78d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} @@ -29,16 +30,18 @@ import org.apache.spark.sql.streaming.ValueState * @param store - reference to the StateStore instance to be used for storing state * @param stateName - name of logical state partition * @param keyEnc - Spark SQL encoder for key + * @param valEncoder - Spark SQL encoder for value * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( store: StateStore, stateName: String, - keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging { + keyExprEnc: ExpressionEncoder[Any], + valEncoder: Encoder[S]) extends ValueState[S] with Logging { private val keySerializer = keyExprEnc.createSerializer() - private val stateTypesEncoder = StateTypesEncoder(keySerializer, stateName) + private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, numColsPrefixKey = 0, VALUE_ROW_SCHEMA) @@ -57,7 +60,7 @@ class ValueStateImpl[S]( override def get(): S = { val retRow = getImpl() if (retRow != null) { - stateTypesEncoder.decodeValue[S](retRow) + stateTypesEncoder.decodeValue(retRow) } else { null.asInstanceOf[S] } diff --git a/sql/core/src/test/java/org/apache/spark/sql/execution/streaming/state/POJOTestClass.java b/sql/core/src/test/java/org/apache/spark/sql/execution/streaming/state/POJOTestClass.java new file mode 100644 index 000000000000..0ba75f7789d0 --- /dev/null +++ 
b/sql/core/src/test/java/org/apache/spark/sql/execution/streaming/state/POJOTestClass.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state; + +/** + * A POJO class used for tests of arbitrary state SQL encoder. 
+ */ +public class POJOTestClass { + // Fields + private String name; + private int id; + + // Constructors + public POJOTestClass() { + // Default constructor + } + + public POJOTestClass(String name, int id) { + this.name = name; + this.id = id; + } + + // Getter methods + public String getName() { + return name; + } + + public int getId() { + return id; + } + + // Setter methods + public void setName(String name) { + this.name = name; + } + + public void setId(int id) { + this.id = id; + } + + // Additional methods if needed + public void incrementId() { + id++; + System.out.println(name + " is now " + id + "!"); + } + + // Override toString for better representation + @Override + public String toString() { + return "POJOTestClass{" + + "name='" + name + '\'' + + ", age=" + id + + '}'; + } + + // Override equals and hashCode for custom equality + @Override + public boolean equals(Object obj) { + POJOTestClass testObj = (POJOTestClass) obj; + return id == testObj.id && name.equals(testObj.name); + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index be77f7a887c7..71462cb4b643 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -33,6 +33,9 @@ import org.apache.spark.sql.streaming.ValueState import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +/** A case class for SQL encoder test purpose */ +case class TestClass(var id: Long, var name: String) + /** * Class that adds tests for single value ValueState types used in arbitrary stateful * operators such as transformWithState @@ -93,7 +96,7 @@ class ValueStateSuite extends SharedSparkSession Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) val stateName = "testState" 
- val testState: ValueState[Long] = handle.getValueState[Long]("testState") + val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty) val ex = intercept[Exception] { testState.update(123) @@ -136,7 +139,7 @@ class ValueStateSuite extends SharedSparkSession val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) - val testState: ValueState[Long] = handle.getValueState[Long]("testState") + val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") testState.update(123) assert(testState.get() === 123) @@ -162,8 +165,10 @@ class ValueStateSuite extends SharedSparkSession val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) - val testState1: ValueState[Long] = handle.getValueState[Long]("testState1") - val testState2: ValueState[Long] = handle.getValueState[Long]("testState2") + val testState1: ValueState[Long] = handle.getValueState[Long]( + "testState1", Encoders.scalaLong) + val testState2: ValueState[Long] = handle.getValueState[Long]( + "testState2", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") testState1.update(123) assert(testState1.get() === 123) @@ -217,4 +222,108 @@ class ValueStateSuite extends SharedSparkSession matchPVals = true ) } + + test("test SQL encoder - Value state operations for Primitive(Double) instances") { + tryWithProviderResource(newStoreProviderWithValueState(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + + val testState: ValueState[Double] = handle.getValueState[Double]("testState", + Encoders.scalaDouble) + 
ImplicitGroupingKeyTracker.setImplicitKey("test_key") + testState.update(1.0) + assert(testState.get().equals(1.0)) + testState.clear() + assert(!testState.exists()) + assert(testState.get() === null) + + testState.update(2.0) + assert(testState.get().equals(2.0)) + testState.update(3.0) + assert(testState.get().equals(3.0)) + + testState.clear() + assert(!testState.exists()) + assert(testState.get() === null) + } + } + + test("test SQL encoder - Value state operations for Primitive(Long) instances") { + tryWithProviderResource(newStoreProviderWithValueState(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + + val testState: ValueState[Long] = handle.getValueState[Long]("testState", + Encoders.scalaLong) + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + testState.update(1L) + assert(testState.get().equals(1L)) + testState.clear() + assert(!testState.exists()) + assert(testState.get() === null) + + testState.update(2L) + assert(testState.get().equals(2L)) + testState.update(3L) + assert(testState.get().equals(3L)) + + testState.clear() + assert(!testState.exists()) + assert(testState.get() === null) + } + } + + test("test SQL encoder - Value state operations for case class instances") { + tryWithProviderResource(newStoreProviderWithValueState(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + + val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState", + Encoders.product[TestClass]) + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + testState.update(TestClass(1, "testcase1")) + assert(testState.get().equals(TestClass(1, "testcase1"))) + testState.clear() + assert(!testState.exists()) + assert(testState.get() === null) + + testState.update(TestClass(2, 
"testcase2")) + assert(testState.get() === TestClass(2, "testcase2")) + testState.update(TestClass(3, "testcase3")) + assert(testState.get() === TestClass(3, "testcase3")) + + testState.clear() + assert(!testState.exists()) + assert(testState.get() === null) + } + } + + test("test SQL encoder - Value state operations for POJO instances") { + tryWithProviderResource(newStoreProviderWithValueState(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + + val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState", + Encoders.bean(classOf[POJOTestClass])) + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + testState.update(new POJOTestClass("testcase1", 1)) + assert(testState.get().equals(new POJOTestClass("testcase1", 1))) + testState.clear() + assert(!testState.exists()) + assert(testState.get() === null) + + testState.update(new POJOTestClass("testcase2", 2)) + assert(testState.get().equals(new POJOTestClass("testcase2", 2))) + testState.update(new POJOTestClass("testcase3", 3)) + assert(testState.get().equals(new POJOTestClass("testcase3", 3))) + + testState.clear() + assert(!testState.exists()) + assert(testState.get() === null) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 3d085da4ab58..9572f7006f37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkIllegalArgumentException +import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming.MemoryStream import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf @@ -30,7 +31,7 @@ class TestListStateProcessor @transient var _listState: ListState[String] = _ override def init(outputMode: OutputMode): Unit = { - _listState = getHandle.getListState("testListState") + _listState = getHandle.getListState("testListState", Encoders.STRING) } override def handleInputRows( @@ -86,8 +87,8 @@ class ToggleSaveAndEmitProcessor @transient var _valueState: ValueState[Boolean] = _ override def init(outputMode: OutputMode): Unit = { - _listState = getHandle.getListState("testListState") - _valueState = getHandle.getValueState("testValueState") + _listState = getHandle.getListState("testListState", Encoders.STRING) + _valueState = getHandle.getValueState("testValueState", Encoders.scalaBoolean) } override def handleInputRows( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 8a87472a023a..cc8c64c94c02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.{SparkException, SparkRuntimeException} import org.apache.spark.internal.Logging +import org.apache.spark.sql.Encoders import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StateStoreMultipleColumnFamiliesNotSupportedException} import org.apache.spark.sql.internal.SQLConf @@ -32,7 +33,7 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S @transient private var _countState: ValueState[Long] = _ override def init(outputMode: 
OutputMode): Unit = { - _countState = getHandle.getValueState[Long]("countState") + _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) } override def handleInputRows( @@ -59,8 +60,8 @@ class RunningCountMostRecentStatefulProcessor @transient private var _mostRecent: ValueState[String] = _ override def init(outputMode: OutputMode): Unit = { - _countState = getHandle.getValueState[Long]("countState") - _mostRecent = getHandle.getValueState[String]("mostRecent") + _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong) + _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } override def handleInputRows( key: String, @@ -88,7 +89,7 @@ class MostRecentStatefulProcessorWithDeletion override def init(outputMode: OutputMode): Unit = { getHandle.deleteIfExists("countState") - _mostRecent = getHandle.getValueState[String]("mostRecent") + _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING) } override def handleInputRows( @@ -116,7 +117,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess inputRows: Iterator[String], timerValues: TimerValues): Iterator[(String, String)] = { // Trying to create value state here should fail - _tempState = getHandle.getValueState[Long]("tempState") + _tempState = getHandle.getValueState[Long]("tempState", Encoders.scalaLong) Iterator.empty } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org