This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e610d1d8f79b [SPARK-46852][SS] Remove use of explicit key encoder and pass it implicitly to the operator for transformWithState operator e610d1d8f79b is described below commit e610d1d8f79b913cb9ee9236a6325202c58d8397 Author: Anish Shrigondekar <anish.shrigonde...@databricks.com> AuthorDate: Thu Feb 1 22:31:07 2024 +0900 [SPARK-46852][SS] Remove use of explicit key encoder and pass it implicitly to the operator for transformWithState operator ### What changes were proposed in this pull request? Remove the explicit key encoder argument from `StatefulProcessorHandle.getValueState`, and instead pass the key encoder implicitly to the transformWithState operator, which hands it to the state variables it creates. ### Why are the changes needed? These changes spare users from having to provide an explicit key encoder. The implicitly passed encoder will also likely be needed for upcoming timer-related changes. ### Does this PR introduce _any_ user-facing change? Yes. `getValueState` no longer takes a key type parameter or a key encoder argument; it is now called as `getValueState[T](stateName)`. ### How was this patch tested? Existing unit tests, updated to the new `getValueState` signature. ### Was this patch authored or co-authored using generative AI tooling? No Closes #44974 from anishshri-db/task/SPARK-46852. 
Authored-by: Anish Shrigondekar <anish.shrigonde...@databricks.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../sql/streaming/StatefulProcessorHandle.scala | 5 +---- .../spark/sql/catalyst/plans/logical/object.scala | 3 +++ .../spark/sql/execution/SparkStrategies.scala | 3 ++- .../streaming/StatefulProcessorHandleImpl.scala | 13 +++++++++---- .../streaming/TransformWithStateExec.scala | 6 +++++- .../sql/execution/streaming/ValueStateImpl.scala | 12 +++++------- .../streaming/state/ValueStateSuite.scala | 22 +++++++++++----------- .../sql/streaming/TransformWithStateSuite.scala | 8 +++----- 8 files changed, 39 insertions(+), 33 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 302de4a3c947..5eaccceb947c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming import java.io.Serializable import org.apache.spark.annotation.{Evolving, Experimental} -import org.apache.spark.sql.Encoder /** * Represents the operation handle provided to the stateful processor used in the @@ -34,12 +33,10 @@ private[sql] trait StatefulProcessorHandle extends Serializable { * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. 
* @param stateName - name of the state variable - * @param keyEncoder - Spark SQL Encoder for key - * @tparam K - type of key * @tparam T - type of state variable * @return - instance of ValueState of type T that can be used to store state persistently */ - def getValueState[K, T](stateName: String, keyEncoder: Encoder[K]): ValueState[T] + def getValueState[T](stateName: String): ValueState[T] /** Function to return queryInfo for currently running task */ def getQueryInfo(): QueryInfo diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 8f937dd5a777..cb8673d20ed3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -577,6 +577,7 @@ object TransformWithState { timeoutMode: TimeoutMode, outputMode: OutputMode, child: LogicalPlan): LogicalPlan = { + val keyEncoder = encoderFor[K] val mapped = new TransformWithState( UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes), @@ -585,6 +586,7 @@ object TransformWithState { statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]], timeoutMode, outputMode, + keyEncoder.asInstanceOf[ExpressionEncoder[Any]], CatalystSerde.generateObjAttr[U], child ) @@ -600,6 +602,7 @@ case class TransformWithState( statefulProcessor: StatefulProcessor[Any, Any, Any], timeoutMode: TimeoutMode, outputMode: OutputMode, + keyEncoder: ExpressionEncoder[Any], outputObjAttr: Attribute, child: LogicalPlan) extends UnaryNode with ObjectProducer { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5d4063d125c8..f5c2f17f8826 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -728,7 +728,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case TransformWithState( keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, - outputAttr, child) => + keyEncoder, outputAttr, child) => val execPlan = TransformWithStateExec( keyDeserializer, valueDeserializer, @@ -737,6 +737,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { statefulProcessor, timeoutMode, outputMode, + keyEncoder, outputAttr, stateInfo = None, batchTimestampMs = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 758e8c646ffc..d0cd8f7dc0a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -20,7 +20,7 @@ import java.util.UUID import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.{QueryInfo, StatefulProcessorHandle, ValueState} import org.apache.spark.util.Utils @@ -67,8 +67,13 @@ class QueryInfoImpl( * Class that provides a concrete implementation of a StatefulProcessorHandle. Note that we keep * track of valid transitions as various functions are invoked to track object lifecycle. 
* @param store - instance of state store + * @param runId - unique id for the current run + * @param keyEncoder - encoder for the key */ -class StatefulProcessorHandleImpl(store: StateStore, runId: UUID) +class StatefulProcessorHandleImpl( + store: StateStore, + runId: UUID, + keyEncoder: ExpressionEncoder[Any]) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ @@ -108,11 +113,11 @@ class StatefulProcessorHandleImpl(store: StateStore, runId: UUID) def getHandleState: StatefulProcessorHandleState = currState - override def getValueState[K, T](stateName: String, keyEncoder: Encoder[K]): ValueState[T] = { + override def getValueState[T](stateName: String): ValueState[T] = { verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + "initialization is complete") store.createColFamilyIfAbsent(stateName) - val resultState = new ValueStateImpl[K, T](store, stateName, keyEncoder) + val resultState = new ValueStateImpl[T](store, stateName, keyEncoder) resultState } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index ce651d959afc..82e827685b47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ @@ -38,6 +39,7 @@ import org.apache.spark.util.CompletionIterator * @param statefulProcessor 
processor methods called on underlying data * @param timeoutMode defines the timeout mode * @param outputMode defines the output mode for the statefulProcessor + * @param keyEncoder expression encoder for the key type * @param outputObjAttr Defines the output object * @param batchTimestampMs processing timestamp of the current batch. * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events @@ -52,6 +54,7 @@ case class TransformWithStateExec( statefulProcessor: StatefulProcessor[Any, Any, Any], timeoutMode: TimeoutMode, outputMode: OutputMode, + keyEncoder: ExpressionEncoder[Any], outputObjAttr: Attribute, stateInfo: Option[StatefulOperatorStateInfo], batchTimestampMs: Option[Long], @@ -162,7 +165,8 @@ case class TransformWithStateExec( useColumnFamilies = true ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => - val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId) + val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, + keyEncoder) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.init(processorHandle, outputMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 91554de97fe3..5a1b6d01baa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -21,9 +21,8 @@ import java.io.Serializable import org.apache.commons.lang3.SerializationUtils import org.apache.spark.internal.Logging -import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.ValueState @@ -38,10 +37,10 @@ import org.apache.spark.sql.types._ * @tparam K - data type of key * @tparam S - data type of object that will be stored */ -class ValueStateImpl[K, S]( +class ValueStateImpl[S]( store: StateStore, stateName: String, - keyEnc: Encoder[K]) extends ValueState[S] with Logging { + keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging { // TODO: validate places that are trying to encode the key and check if we can eliminate/ // add caching for some of these calls. @@ -52,10 +51,9 @@ class ValueStateImpl[K, S]( s"stateName=$stateName") } - val exprEnc: ExpressionEncoder[K] = encoderFor(keyEnc) - val toRow = exprEnc.createSerializer() + val toRow = keyExprEnc.createSerializer() val keyByteArr = toRow - .apply(keyOption.get.asInstanceOf[K]).asInstanceOf[UnsafeRow].getBytes() + .apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes() val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) val keyEncoder = UnsafeProjection.create(schemaForKeyRow) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 6d929498d65b..49a5fff131ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.ValueState @@ -87,10 +88,10 @@ class ValueStateSuite extends SharedSparkSession test("Implicit key operations") { tryWithProviderResource(newStoreProviderWithValueState(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID()) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) - val testState: ValueState[Long] = handle.getValueState[String, Long]("testState", - Encoders.STRING) + val testState: ValueState[Long] = handle.getValueState[Long]("testState") assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty) val ex = intercept[Exception] { testState.update(123) @@ -118,10 +119,10 @@ class ValueStateSuite extends SharedSparkSession test("Value state operations for single instance") { tryWithProviderResource(newStoreProviderWithValueState(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID()) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) - val testState: ValueState[Long] = handle.getValueState[String, Long]("testState", - Encoders.STRING) + val testState: ValueState[Long] = handle.getValueState[Long]("testState") ImplicitGroupingKeyTracker.setImplicitKey("test_key") testState.update(123) assert(testState.get() === 123) @@ -144,12 +145,11 @@ class ValueStateSuite extends SharedSparkSession test("Value state operations for multiple instances") { tryWithProviderResource(newStoreProviderWithValueState(true)) { provider => val store = provider.getStore(0) - val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID()) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) - val testState1: ValueState[Long] = 
handle.getValueState[String, Long]("testState1", - Encoders.STRING) - val testState2: ValueState[Long] = handle.getValueState[String, Long]("testState2", - Encoders.STRING) + val testState1: ValueState[Long] = handle.getValueState[Long]("testState1") + val testState2: ValueState[Long] = handle.getValueState[Long]("testState2") ImplicitGroupingKeyTracker.setImplicitKey("test_key") testState1.update(123) assert(testState1.get() === 123) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 9909919c0cae..70a71f745066 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Encoders, SaveMode} +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} import org.apache.spark.sql.internal.SQLConf @@ -38,8 +38,7 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S outputMode: OutputMode) : Unit = { _processorHandle = handle assert(handle.getQueryInfo().getBatchId >= 0) - _countState = _processorHandle.getValueState[String, Long]("countState", - Encoders.STRING) + _countState = _processorHandle.getValueState[Long]("countState") } override def handleInputRows( @@ -67,8 +66,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess inputRows: Iterator[String], timerValues: TimerValues): Iterator[(String, String)] = { // Trying to create value state here should fail - _tempState = 
_processorHandle.getValueState[String, Long]("tempState", - Encoders.STRING) + _tempState = _processorHandle.getValueState[Long]("tempState") Iterator.empty } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org