This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6b5917beff30 [SPARK-46961][SS] Using ProcessorContext to store and retrieve handle 6b5917beff30 is described below commit 6b5917beff30c813a362584a135a587001df1390 Author: Eric Marnadi <eric.marn...@databricks.com> AuthorDate: Mon Mar 4 21:20:23 2024 +0300 [SPARK-46961][SS] Using ProcessorContext to store and retrieve handle ### What changes were proposed in this pull request? Setting the processorHandle as a part of the statefulProcessor, so that the user doesn't have to explicitly keep track of it, and can instead simply call `getHandle`. ### Why are the changes needed? This enhances the usability of the State API ### Does this PR introduce _any_ user-facing change? Yes, this is an API change. This enhances usability of the StatefulProcessorHandle and the TransformWithState operator. ### How was this patch tested? Existing unit tests are sufficient ### Was this patch authored or co-authored using generative AI tooling? No Closes #45359 from ericm-db/handle-context. 
Authored-by: Eric Marnadi <eric.marn...@databricks.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../src/main/resources/error/error-classes.json | 7 +++ docs/sql-error-conditions.md | 7 +++ .../apache/spark/sql/errors/ExecutionErrors.scala | 6 +++ .../spark/sql/streaming/StatefulProcessor.scala | 38 ++++++++++++--- .../streaming/TransformWithStateExec.scala | 4 +- .../streaming/TransformWithListStateSuite.scala | 14 ++---- .../sql/streaming/TransformWithStateSuite.scala | 54 ++++++++++------------ 7 files changed, 84 insertions(+), 46 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 6ccd841ccd0f..7cf3e9c533ca 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3337,6 +3337,13 @@ ], "sqlState" : "42802" }, + "STATE_STORE_HANDLE_NOT_INITIALIZED" : { + "message" : [ + "The handle has not been initialized for this StatefulProcessor.", + "Please only use the StatefulProcessor within the transformWithState operator." + ], + "sqlState" : "42802" + }, "STATE_STORE_MULTIPLE_VALUES_PER_KEY" : { "message" : [ "Store does not support multiple values per key" diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index f026c456eb2d..7be01f8cb513 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2091,6 +2091,13 @@ Star (*) is not allowed in a select list when GROUP BY an ordinal position is us Failed to remove default column family with reserved name=`<colFamilyName>`. +### STATE_STORE_HANDLE_NOT_INITIALIZED + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +The handle has not been initialized for this StatefulProcessor. +Please only use the StatefulProcessor within the transformWithState operator. 
+ ### STATE_STORE_MULTIPLE_VALUES_PER_KEY [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala index b74a67b49bda..7910c386fcf1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala @@ -53,6 +53,12 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { e) } + def stateStoreHandleNotInitialized(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "STATE_STORE_HANDLE_NOT_INITIALIZED", + messageParameters = Map.empty) + } + def failToRecognizePatternAfterUpgradeError( pattern: String, e: Throwable): SparkUpgradeException = { new SparkUpgradeException( diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index 76794136dd49..42a9430bf39d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming import java.io.Serializable import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.sql.errors.ExecutionErrors /** * Represents the arbitrary stateful logic that needs to be provided by the user to perform @@ -29,17 +30,18 @@ import org.apache.spark.annotation.{Evolving, Experimental} @Evolving private[sql] trait StatefulProcessor[K, I, O] extends Serializable { + /** + * Handle to the stateful processor that provides access to the state store and other + * stateful processing related APIs. 
+ */ + private var statefulProcessorHandle: StatefulProcessorHandle = null + /** * Function that will be invoked as the first method that allows for users to * initialize all their state variables and perform other init actions before handling data. - * @param handle - reference to the statefulProcessorHandle that the user can use to perform - * actions like creating state variables, accessing queryInfo etc. Please refer to - * [[StatefulProcessorHandle]] for more details. * @param outputMode - output mode for the stateful processor */ - def init( - handle: StatefulProcessorHandle, - outputMode: OutputMode): Unit + def init(outputMode: OutputMode): Unit /** * Function that will allow users to interact with input data rows along with the grouping key @@ -59,5 +61,27 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable { * Function called as the last method that allows for users to perform * any cleanup or teardown operations. */ - def close (): Unit + def close (): Unit = {} + + /** + * Function to set the stateful processor handle that will be used to interact with the state + * store and other stateful processor related operations. 
+ * + * @param handle - instance of StatefulProcessorHandle + */ + final def setHandle(handle: StatefulProcessorHandle): Unit = { + statefulProcessorHandle = handle + } + + /** + * Function to get the stateful processor handle that will be used to interact with the state + * + * @return handle - instance of StatefulProcessorHandle + */ + final def getHandle: StatefulProcessorHandle = { + if (statefulProcessorHandle == null) { + throw ExecutionErrors.stateStoreHandleNotInitialized() + } + statefulProcessorHandle + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 5a80fb1209ba..117bc722f09e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -156,6 +156,7 @@ case class TransformWithStateExec( setStoreMetrics(store) setOperatorMetrics() statefulProcessor.close() + statefulProcessor.setHandle(null) processorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) }) } @@ -228,7 +229,8 @@ case class TransformWithStateExec( val processorHandle = new StatefulProcessorHandleImpl( store, getStateInfo.queryRunId, keyEncoder, isStreaming) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) - statefulProcessor.init(processorHandle, outputMode) + statefulProcessor.setHandle(processorHandle) + statefulProcessor.init(outputMode) processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) processDataWithPartition(singleIterator, store, processorHandle) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index f7ed813badde..3d085da4ab58 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -27,12 +27,10 @@ case class InputRow(key: String, action: String, value: String) class TestListStateProcessor extends StatefulProcessor[String, InputRow, (String, String)] { - @transient var _processorHandle: StatefulProcessorHandle = _ @transient var _listState: ListState[String] = _ - override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = { - _processorHandle = handle - _listState = handle.getListState("testListState") + override def init(outputMode: OutputMode): Unit = { + _listState = getHandle.getListState("testListState") } override def handleInputRows( @@ -84,14 +82,12 @@ class TestListStateProcessor class ToggleSaveAndEmitProcessor extends StatefulProcessor[String, String, String] { - @transient var _processorHandle: StatefulProcessorHandle = _ @transient var _listState: ListState[String] = _ @transient var _valueState: ValueState[Boolean] = _ - override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = { - _processorHandle = handle - _listState = handle.getListState("testListState") - _valueState = handle.getValueState("testValueState") + override def init(outputMode: OutputMode): Unit = { + _listState = getHandle.getListState("testListState") + _valueState = getHandle.getValueState("testValueState") } override def handleInputRows( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index a4a04e0b5077..8a87472a023a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import 
org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkRuntimeException} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StateStoreMultipleColumnFamiliesNotSupportedException} @@ -30,14 +30,9 @@ object TransformWithStateSuiteUtils { class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient private var _countState: ValueState[Long] = _ - @transient var _processorHandle: StatefulProcessorHandle = _ - - override def init( - handle: StatefulProcessorHandle, - outputMode: OutputMode) : Unit = { - _processorHandle = handle - assert(handle.getQueryInfo().getBatchId >= 0) - _countState = _processorHandle.getValueState[Long]("countState") + + override def init(outputMode: OutputMode): Unit = { + _countState = getHandle.getValueState[Long]("countState") } override def handleInputRows( @@ -62,17 +57,11 @@ class RunningCountMostRecentStatefulProcessor with Logging { @transient private var _countState: ValueState[Long] = _ @transient private var _mostRecent: ValueState[String] = _ - @transient var _processorHandle: StatefulProcessorHandle = _ - - override def init( - handle: StatefulProcessorHandle, - outputMode: OutputMode) : Unit = { - _processorHandle = handle - assert(handle.getQueryInfo().getBatchId >= 0) - _countState = _processorHandle.getValueState[Long]("countState") - _mostRecent = _processorHandle.getValueState[String]("mostRecent") - } + override def init(outputMode: OutputMode): Unit = { + _countState = getHandle.getValueState[Long]("countState") + _mostRecent = getHandle.getValueState[String]("mostRecent") + } override def handleInputRows( key: String, inputRows: Iterator[(String, String)], @@ -96,15 +85,10 @@ class MostRecentStatefulProcessorWithDeletion extends StatefulProcessor[String, (String, String), 
(String, String)] with Logging { @transient private var _mostRecent: ValueState[String] = _ - @transient var _processorHandle: StatefulProcessorHandle = _ - - override def init( - handle: StatefulProcessorHandle, - outputMode: OutputMode) : Unit = { - _processorHandle = handle - assert(handle.getQueryInfo().getBatchId >= 0) - _processorHandle.deleteIfExists("countState") - _mostRecent = _processorHandle.getValueState[String]("mostRecent") + + override def init(outputMode: OutputMode): Unit = { + getHandle.deleteIfExists("countState") + _mostRecent = getHandle.getValueState[String]("mostRecent") } override def handleInputRows( @@ -132,7 +116,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess inputRows: Iterator[String], timerValues: TimerValues): Iterator[(String, String)] = { // Trying to create value state here should fail - _tempState = _processorHandle.getValueState[Long]("tempState") + _tempState = getHandle.getValueState[Long]("tempState") Iterator.empty } } @@ -195,6 +179,18 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("Use statefulProcessor without transformWithState - handle should be absent") { + val processor = new RunningCountStatefulProcessor() + val ex = intercept[Exception] { + processor.getHandle + } + checkError( + ex.asInstanceOf[SparkRuntimeException], + errorClass = "STATE_STORE_HANDLE_NOT_INITIALIZED", + parameters = Map.empty + ) + } + test("transformWithState - batch should succeed") { val inputData = Seq("a", "b") val result = inputData.toDS() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org