This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e610d1d8f79b [SPARK-46852][SS] Remove use of explicit key encoder and pass it implicitly to the operator for transformWithState operator
e610d1d8f79b is described below

commit e610d1d8f79b913cb9ee9236a6325202c58d8397
Author: Anish Shrigondekar <anish.shrigonde...@databricks.com>
AuthorDate: Thu Feb 1 22:31:07 2024 +0900

    [SPARK-46852][SS] Remove use of explicit key encoder and pass it implicitly to the operator for transformWithState operator
    
    ### What changes were proposed in this pull request?
    Remove use of explicit key encoder and pass it implicitly to the operator for transformWithState operator
    
    ### Why are the changes needed?
    Changes are needed to avoid requiring users to provide an explicit key encoder; we may also need the implicitly passed encoder for subsequent timer-related changes.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes - the getValueState API no longer takes a key type parameter or an explicit key encoder argument.
    
    ### How was this patch tested?
    Existing unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #44974 from anishshri-db/task/SPARK-46852.
    
    Authored-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/streaming/StatefulProcessorHandle.scala    |  5 +----
 .../spark/sql/catalyst/plans/logical/object.scala  |  3 +++
 .../spark/sql/execution/SparkStrategies.scala      |  3 ++-
 .../streaming/StatefulProcessorHandleImpl.scala    | 13 +++++++++----
 .../streaming/TransformWithStateExec.scala         |  6 +++++-
 .../sql/execution/streaming/ValueStateImpl.scala   | 12 +++++-------
 .../streaming/state/ValueStateSuite.scala          | 22 +++++++++++-----------
 .../sql/streaming/TransformWithStateSuite.scala    |  8 +++-----
 8 files changed, 39 insertions(+), 33 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
index 302de4a3c947..5eaccceb947c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming
 import java.io.Serializable
 
 import org.apache.spark.annotation.{Evolving, Experimental}
-import org.apache.spark.sql.Encoder
 
 /**
  * Represents the operation handle provided to the stateful processor used in the
@@ -34,12 +33,10 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
    * The user must ensure to call this function only within the `init()` method of the
    * StatefulProcessor.
    * @param stateName - name of the state variable
-   * @param keyEncoder - Spark SQL Encoder for key
-   * @tparam K - type of key
    * @tparam T - type of state variable
    * @return - instance of ValueState of type T that can be used to store state persistently
    */
-  def getValueState[K, T](stateName: String, keyEncoder: Encoder[K]): ValueState[T]
+  def getValueState[T](stateName: String): ValueState[T]
 
   /** Function to return queryInfo for currently running task */
   def getQueryInfo(): QueryInfo
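
To illustrate the API change above: a minimal sketch (not part of this diff) of how a user's StatefulProcessor now creates state inside `init()`; the `_countState` field is hypothetical, mirroring the updated test processor at the end of this patch.

    // Before this change, the key type and encoder had to be supplied explicitly:
    //   _countState = handle.getValueState[String, Long]("countState", Encoders.STRING)
    // After this change, the operator passes the key encoder implicitly:
    _countState = handle.getValueState[Long]("countState")
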
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 8f937dd5a777..cb8673d20ed3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -577,6 +577,7 @@ object TransformWithState {
       timeoutMode: TimeoutMode,
       outputMode: OutputMode,
       child: LogicalPlan): LogicalPlan = {
+    val keyEncoder = encoderFor[K]
     val mapped = new TransformWithState(
       UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
       UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
@@ -585,6 +586,7 @@ object TransformWithState {
       statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]],
       timeoutMode,
       outputMode,
+      keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
       CatalystSerde.generateObjAttr[U],
       child
     )
@@ -600,6 +602,7 @@ case class TransformWithState(
     statefulProcessor: StatefulProcessor[Any, Any, Any],
     timeoutMode: TimeoutMode,
     outputMode: OutputMode,
+    keyEncoder: ExpressionEncoder[Any],
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectProducer {
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5d4063d125c8..f5c2f17f8826 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -728,7 +728,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case TransformWithState(
         keyDeserializer, valueDeserializer, groupingAttributes,
         dataAttributes, statefulProcessor, timeoutMode, outputMode,
-        outputAttr, child) =>
+        keyEncoder, outputAttr, child) =>
         val execPlan = TransformWithStateExec(
           keyDeserializer,
           valueDeserializer,
@@ -737,6 +737,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           statefulProcessor,
           timeoutMode,
           outputMode,
+          keyEncoder,
           outputAttr,
           stateInfo = None,
           batchTimestampMs = None,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 758e8c646ffc..d0cd8f7dc0a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -20,7 +20,7 @@ import java.util.UUID
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.{QueryInfo, StatefulProcessorHandle, ValueState}
 import org.apache.spark.util.Utils
@@ -67,8 +67,13 @@ class QueryInfoImpl(
  * Class that provides a concrete implementation of a StatefulProcessorHandle. Note that we keep
  * track of valid transitions as various functions are invoked to track object lifecycle.
  * @param store - instance of state store
+ * @param runId - unique id for the current run
+ * @param keyEncoder - encoder for the key
  */
-class StatefulProcessorHandleImpl(store: StateStore, runId: UUID)
+class StatefulProcessorHandleImpl(
+    store: StateStore,
+    runId: UUID,
+    keyEncoder: ExpressionEncoder[Any])
   extends StatefulProcessorHandle with Logging {
   import StatefulProcessorHandleState._
 
@@ -108,11 +113,11 @@ class StatefulProcessorHandleImpl(store: StateStore, runId: UUID)
 
   def getHandleState: StatefulProcessorHandleState = currState
 
-  override def getValueState[K, T](stateName: String, keyEncoder: Encoder[K]): ValueState[T] = {
+  override def getValueState[T](stateName: String): ValueState[T] = {
     verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " +
       "initialization is complete")
     store.createColFamilyIfAbsent(stateName)
-    val resultState = new ValueStateImpl[K, T](store, stateName, keyEncoder)
+    val resultState = new ValueStateImpl[T](store, stateName, keyEncoder)
     resultState
   }
 
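With the new constructor, callers build the handle with an already-erased encoder. A minimal sketch assuming a string grouping key and an open StateStore named `store`, mirroring the updated ValueStateSuite below:

    import java.util.UUID
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // Encoders.STRING is an ExpressionEncoder at runtime, so the cast is safe.
    val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
      Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
    val testState = handle.getValueState[Long]("testState")
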
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index ce651d959afc..82e827685b47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution._
@@ -38,6 +39,7 @@ import org.apache.spark.util.CompletionIterator
  * @param statefulProcessor processor methods called on underlying data
  * @param timeoutMode defines the timeout mode
  * @param outputMode defines the output mode for the statefulProcessor
+ * @param keyEncoder expression encoder for the key type
  * @param outputObjAttr Defines the output object
  * @param batchTimestampMs processing timestamp of the current batch.
  * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
@@ -52,6 +54,7 @@ case class TransformWithStateExec(
     statefulProcessor: StatefulProcessor[Any, Any, Any],
     timeoutMode: TimeoutMode,
     outputMode: OutputMode,
+    keyEncoder: ExpressionEncoder[Any],
     outputObjAttr: Attribute,
     stateInfo: Option[StatefulOperatorStateInfo],
     batchTimestampMs: Option[Long],
@@ -162,7 +165,8 @@ case class TransformWithStateExec(
       useColumnFamilies = true
     ) {
       case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
-        val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId)
+        val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
+          keyEncoder)
         assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
         statefulProcessor.init(processorHandle, outputMode)
         processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
index 91554de97fe3..5a1b6d01baa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
@@ -21,9 +21,8 @@ import java.io.Serializable
 import org.apache.commons.lang3.SerializationUtils
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.ValueState
@@ -38,10 +37,10 @@ import org.apache.spark.sql.types._
  * @tparam K - data type of key
  * @tparam S - data type of object that will be stored
  */
-class ValueStateImpl[K, S](
+class ValueStateImpl[S](
     store: StateStore,
     stateName: String,
-    keyEnc: Encoder[K]) extends ValueState[S] with Logging {
+    keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging {
 
   // TODO: validate places that are trying to encode the key and check if we can eliminate/
   // add caching for some of these calls.
@@ -52,10 +51,9 @@ class ValueStateImpl[K, S](
         s"stateName=$stateName")
     }
 
-    val exprEnc: ExpressionEncoder[K] = encoderFor(keyEnc)
-    val toRow = exprEnc.createSerializer()
+    val toRow = keyExprEnc.createSerializer()
     val keyByteArr = toRow
-      .apply(keyOption.get.asInstanceOf[K]).asInstanceOf[UnsafeRow].getBytes()
+      .apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
 
     val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
     val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
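
To make the serialization path above concrete: a minimal standalone sketch, assuming a string grouping key, of how the erased encoder turns a key object into the raw bytes used as the state store key. encoderFor is the same catalyst helper the removed code used.

    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    val keyExprEnc = encoderFor(Encoders.STRING).asInstanceOf[ExpressionEncoder[Any]]
    val toRow = keyExprEnc.createSerializer()
    // The serializer produces an UnsafeRow; its bytes become the store key.
    val keyByteArr = toRow.apply("test_key").asInstanceOf[UnsafeRow].getBytes()
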
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 6d929498d65b..49a5fff131ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.ValueState
@@ -87,10 +88,10 @@ class ValueStateSuite extends SharedSparkSession
   test("Implicit key operations") {
     tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
       val store = provider.getStore(0)
-      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
 
-      val testState: ValueState[Long] = handle.getValueState[String, Long]("testState",
-        Encoders.STRING)
+      val testState: ValueState[Long] = handle.getValueState[Long]("testState")
       assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
       val ex = intercept[Exception] {
         testState.update(123)
@@ -118,10 +119,10 @@ class ValueStateSuite extends SharedSparkSession
   test("Value state operations for single instance") {
     tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
       val store = provider.getStore(0)
-      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
 
-      val testState: ValueState[Long] = handle.getValueState[String, Long]("testState",
-        Encoders.STRING)
+      val testState: ValueState[Long] = handle.getValueState[Long]("testState")
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState.update(123)
       assert(testState.get() === 123)
@@ -144,12 +145,11 @@ class ValueStateSuite extends SharedSparkSession
   test("Value state operations for multiple instances") {
     tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
       val store = provider.getStore(0)
-      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
 
-      val testState1: ValueState[Long] = handle.getValueState[String, Long]("testState1",
-        Encoders.STRING)
-      val testState2: ValueState[Long] = handle.getValueState[String, Long]("testState2",
-        Encoders.STRING)
+      val testState1: ValueState[Long] = handle.getValueState[Long]("testState1")
+      val testState2: ValueState[Long] = handle.getValueState[Long]("testState2")
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState1.update(123)
       assert(testState1.get() === 123)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 9909919c0cae..70a71f745066 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming
 
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Encoders, SaveMode}
+import org.apache.spark.sql.{AnalysisException, SaveMode}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
@@ -38,8 +38,7 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S
       outputMode: OutputMode) : Unit = {
     _processorHandle = handle
     assert(handle.getQueryInfo().getBatchId >= 0)
-    _countState = _processorHandle.getValueState[String, Long]("countState",
-      Encoders.STRING)
+    _countState = _processorHandle.getValueState[Long]("countState")
   }
 
   override def handleInputRows(
@@ -67,8 +66,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess
       inputRows: Iterator[String],
       timerValues: TimerValues): Iterator[(String, String)] = {
     // Trying to create value state here should fail
-    _tempState = _processorHandle.getValueState[String, Long]("tempState",
-      Encoders.STRING)
+    _tempState = _processorHandle.getValueState[Long]("tempState")
     Iterator.empty
   }
 }

