This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b5917beff30 [SPARK-46961][SS] Using ProcessorContext to store and retrieve handle
6b5917beff30 is described below

commit 6b5917beff30c813a362584a135a587001df1390
Author: Eric Marnadi <eric.marn...@databricks.com>
AuthorDate: Mon Mar 4 21:20:23 2024 +0300

    [SPARK-46961][SS] Using ProcessorContext to store and retrieve handle
    
    ### What changes were proposed in this pull request?
    
    The processor handle is now stored as part of the StatefulProcessor itself, so the user doesn't have to keep track of it explicitly and can instead simply call `getHandle`.
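    
    For illustration, a minimal sketch of the new usage pattern, modeled on the test processors updated in this PR (the class name, state variable, and counting logic are hypothetical; state accessor calls follow the ValueState usage in the existing suites):
    
    ```scala
    import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimerValues, ValueState}
    
    class CountingProcessor extends StatefulProcessor[String, String, (String, String)] {
      @transient private var _countState: ValueState[Long] = _
    
      // The transformWithState operator calls setHandle before init, so the
      // handle can be fetched here instead of arriving as an init parameter.
      override def init(outputMode: OutputMode): Unit = {
        _countState = getHandle.getValueState[Long]("countState")
      }
    
      override def handleInputRows(
          key: String,
          inputRows: Iterator[String],
          timerValues: TimerValues): Iterator[(String, String)] = {
        // Accumulate a running count per key and emit the updated total.
        val count = _countState.getOption().getOrElse(0L) + inputRows.size
        _countState.update(count)
        Iterator.single((key, count.toString))
      }
    }
    ```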
    
    ### Why are the changes needed?
    
    This enhances the usability of the State API
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this is an API change. This enhances usability of the StatefulProcessorHandle and the TransformWithState operator.
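    
    In sketch form, the signature change on `StatefulProcessor` (summarized from the diff below):
    
    ```scala
    // Before: the handle was passed into init explicitly.
    def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit
    
    // After: init takes only the output mode. The operator injects the handle
    // via setHandle, and getHandle throws STATE_STORE_HANDLE_NOT_INITIALIZED
    // if the processor is used outside the transformWithState operator.
    def init(outputMode: OutputMode): Unit
    final def setHandle(handle: StatefulProcessorHandle): Unit
    final def getHandle: StatefulProcessorHandle
    ```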
    
    ### How was this patch tested?
    
    Existing unit tests are sufficient
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45359 from ericm-db/handle-context.
    
    Authored-by: Eric Marnadi <eric.marn...@databricks.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../src/main/resources/error/error-classes.json    |  7 +++
 docs/sql-error-conditions.md                       |  7 +++
 .../apache/spark/sql/errors/ExecutionErrors.scala  |  6 +++
 .../spark/sql/streaming/StatefulProcessor.scala    | 38 ++++++++++++---
 .../streaming/TransformWithStateExec.scala         |  4 +-
 .../streaming/TransformWithListStateSuite.scala    | 14 ++----
 .../sql/streaming/TransformWithStateSuite.scala    | 54 ++++++++++------------
 7 files changed, 84 insertions(+), 46 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 6ccd841ccd0f..7cf3e9c533ca 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3337,6 +3337,13 @@
     ],
     "sqlState" : "42802"
   },
+  "STATE_STORE_HANDLE_NOT_INITIALIZED" : {
+    "message" : [
+      "The handle has not been initialized for this StatefulProcessor.",
+      "Please only use the StatefulProcessor within the transformWithState 
operator."
+    ],
+    "sqlState" : "42802"
+  },
   "STATE_STORE_MULTIPLE_VALUES_PER_KEY" : {
     "message" : [
       "Store does not support multiple values per key"
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index f026c456eb2d..7be01f8cb513 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -2091,6 +2091,13 @@ Star (*) is not allowed in a select list when GROUP BY an ordinal position is us
 
 Failed to remove default column family with reserved name=`<colFamilyName>`.
 
+### STATE_STORE_HANDLE_NOT_INITIALIZED
+
+[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+The handle has not been initialized for this StatefulProcessor.
+Please only use the StatefulProcessor within the transformWithState operator.
+
 ### STATE_STORE_MULTIPLE_VALUES_PER_KEY
 
 [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index b74a67b49bda..7910c386fcf1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -53,6 +53,12 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
       e)
   }
 
+  def stateStoreHandleNotInitialized(): SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "STATE_STORE_HANDLE_NOT_INITIALIZED",
+      messageParameters = Map.empty)
+  }
+
   def failToRecognizePatternAfterUpgradeError(
       pattern: String, e: Throwable): SparkUpgradeException = {
     new SparkUpgradeException(
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
index 76794136dd49..42a9430bf39d 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming
 import java.io.Serializable
 
 import org.apache.spark.annotation.{Evolving, Experimental}
+import org.apache.spark.sql.errors.ExecutionErrors
 
 /**
  * Represents the arbitrary stateful logic that needs to be provided by the user to perform
@@ -29,17 +30,18 @@ import org.apache.spark.annotation.{Evolving, Experimental}
 @Evolving
 private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
 
+  /**
+   * Handle to the stateful processor that provides access to the state store and other
+   * stateful processing related APIs.
+   */
+  private var statefulProcessorHandle: StatefulProcessorHandle = null
+
   /**
    * Function that will be invoked as the first method that allows for users to
    * initialize all their state variables and perform other init actions before handling data.
-   * @param handle - reference to the statefulProcessorHandle that the user can use to perform
-   *               actions like creating state variables, accessing queryInfo etc. Please refer to
-   *               [[StatefulProcessorHandle]] for more details.
    * @param outputMode - output mode for the stateful processor
    */
-  def init(
-      handle: StatefulProcessorHandle,
-      outputMode: OutputMode): Unit
+  def init(outputMode: OutputMode): Unit
 
   /**
    * Function that will allow users to interact with input data rows along with the grouping key
@@ -59,5 +61,27 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
    * Function called as the last method that allows for users to perform
    * any cleanup or teardown operations.
    */
-  def close (): Unit
+  def close (): Unit = {}
+
+  /**
+   * Function to set the stateful processor handle that will be used to interact with the state
+   * store and other stateful processor related operations.
+   *
+   * @param handle - instance of StatefulProcessorHandle
+   */
+  final def setHandle(handle: StatefulProcessorHandle): Unit = {
+    statefulProcessorHandle = handle
+  }
+
+  /**
+   * Function to get the stateful processor handle that will be used to interact with the state
+   *
+   * @return handle - instance of StatefulProcessorHandle
+   */
+  final def getHandle: StatefulProcessorHandle = {
+    if (statefulProcessorHandle == null) {
+      throw ExecutionErrors.stateStoreHandleNotInitialized()
+    }
+    statefulProcessorHandle
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 5a80fb1209ba..117bc722f09e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -156,6 +156,7 @@ case class TransformWithStateExec(
       setStoreMetrics(store)
       setOperatorMetrics()
       statefulProcessor.close()
+      statefulProcessor.setHandle(null)
       processorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
     })
   }
@@ -228,7 +229,8 @@ case class TransformWithStateExec(
     val processorHandle = new StatefulProcessorHandleImpl(
       store, getStateInfo.queryRunId, keyEncoder, isStreaming)
    assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
-    statefulProcessor.init(processorHandle, outputMode)
+    statefulProcessor.setHandle(processorHandle)
+    statefulProcessor.init(outputMode)
     processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
     processDataWithPartition(singleIterator, store, processorHandle)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
index f7ed813badde..3d085da4ab58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -27,12 +27,10 @@ case class InputRow(key: String, action: String, value: String)
 class TestListStateProcessor
   extends StatefulProcessor[String, InputRow, (String, String)] {
 
-  @transient var _processorHandle: StatefulProcessorHandle = _
   @transient var _listState: ListState[String] = _
 
-  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
-    _processorHandle = handle
-    _listState = handle.getListState("testListState")
+  override def init(outputMode: OutputMode): Unit = {
+    _listState = getHandle.getListState("testListState")
   }
 
   override def handleInputRows(
@@ -84,14 +82,12 @@ class TestListStateProcessor
 class ToggleSaveAndEmitProcessor
   extends StatefulProcessor[String, String, String] {
 
-  @transient var _processorHandle: StatefulProcessorHandle = _
   @transient var _listState: ListState[String] = _
   @transient var _valueState: ValueState[Boolean] = _
 
-  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
-    _processorHandle = handle
-    _listState = handle.getListState("testListState")
-    _valueState = handle.getValueState("testValueState")
+  override def init(outputMode: OutputMode): Unit = {
+    _listState = getHandle.getListState("testListState")
+    _valueState = getHandle.getValueState("testValueState")
   }
 
   override def handleInputRows(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index a4a04e0b5077..8a87472a023a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.streaming
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkRuntimeException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StateStoreMultipleColumnFamiliesNotSupportedException}
@@ -30,14 +30,9 @@ object TransformWithStateSuiteUtils {
class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)]
   with Logging {
   @transient private var _countState: ValueState[Long] = _
-  @transient var _processorHandle: StatefulProcessorHandle = _
-
-  override def init(
-      handle: StatefulProcessorHandle,
-      outputMode: OutputMode) : Unit = {
-    _processorHandle = handle
-    assert(handle.getQueryInfo().getBatchId >= 0)
-    _countState = _processorHandle.getValueState[Long]("countState")
+
+  override def init(outputMode: OutputMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState")
   }
 
   override def handleInputRows(
@@ -62,17 +57,11 @@ class RunningCountMostRecentStatefulProcessor
   with Logging {
   @transient private var _countState: ValueState[Long] = _
   @transient private var _mostRecent: ValueState[String] = _
-  @transient var _processorHandle: StatefulProcessorHandle = _
-
-  override def init(
-      handle: StatefulProcessorHandle,
-      outputMode: OutputMode) : Unit = {
-    _processorHandle = handle
-    assert(handle.getQueryInfo().getBatchId >= 0)
-    _countState = _processorHandle.getValueState[Long]("countState")
-    _mostRecent = _processorHandle.getValueState[String]("mostRecent")
-  }
 
+  override def init(outputMode: OutputMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState")
+    _mostRecent = getHandle.getValueState[String]("mostRecent")
+  }
   override def handleInputRows(
       key: String,
       inputRows: Iterator[(String, String)],
@@ -96,15 +85,10 @@ class MostRecentStatefulProcessorWithDeletion
   extends StatefulProcessor[String, (String, String), (String, String)]
   with Logging {
   @transient private var _mostRecent: ValueState[String] = _
-  @transient var _processorHandle: StatefulProcessorHandle = _
-
-  override def init(
-       handle: StatefulProcessorHandle,
-       outputMode: OutputMode) : Unit = {
-    _processorHandle = handle
-    assert(handle.getQueryInfo().getBatchId >= 0)
-    _processorHandle.deleteIfExists("countState")
-    _mostRecent = _processorHandle.getValueState[String]("mostRecent")
+
+  override def init(outputMode: OutputMode): Unit = {
+    getHandle.deleteIfExists("countState")
+    _mostRecent = getHandle.getValueState[String]("mostRecent")
   }
 
   override def handleInputRows(
@@ -132,7 +116,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess
       inputRows: Iterator[String],
       timerValues: TimerValues): Iterator[(String, String)] = {
     // Trying to create value state here should fail
-    _tempState = _processorHandle.getValueState[Long]("tempState")
+    _tempState = getHandle.getValueState[Long]("tempState")
     Iterator.empty
   }
 }
@@ -195,6 +179,18 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
+  test("Use statefulProcessor without transformWithState - handle should be 
absent") {
+    val processor = new RunningCountStatefulProcessor()
+    val ex = intercept[Exception] {
+      processor.getHandle
+    }
+    checkError(
+      ex.asInstanceOf[SparkRuntimeException],
+      errorClass = "STATE_STORE_HANDLE_NOT_INITIALIZED",
+      parameters = Map.empty
+    )
+  }
+
   test("transformWithState - batch should succeed") {
     val inputData = Seq("a", "b")
     val result = inputData.toDS()

