This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new becfb94e1c71 [SPARK-46865][SS] Add Batch Support for TransformWithState Operator
becfb94e1c71 is described below

commit becfb94e1c713d10dac83300d096be490a912fd2
Author: Eric Marnadi <eric.marn...@databricks.com>
AuthorDate: Thu Feb 8 12:15:20 2024 +0900

[SPARK-46865][SS] Add Batch Support for TransformWithState Operator

### What changes were proposed in this pull request?

We are allowing batch queries to use and define the `TransformWithState` operator, which was initially introduced for streaming.

### Why are the changes needed?

This is needed to maintain parity between the streaming and batch APIs, since we want everything supported in streaming to be supported in batch as well.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit tests that use the TransformWithState operator with a batch query.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #44884 from ericm-db/tws-batch.
Lead-authored-by: Eric Marnadi <eric.marn...@databricks.com> Co-authored-by: ericm-db <132308037+ericm...@users.noreply.github.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../analysis/UnsupportedOperationChecker.scala | 3 - .../spark/sql/execution/SparkStrategies.scala | 9 +- .../execution/streaming/IncrementalExecution.scala | 2 +- .../streaming/StatefulProcessorHandleImpl.scala | 25 ++-- .../streaming/TransformWithStateExec.scala | 138 ++++++++++++++++++--- .../sql/streaming/TransformWithStateSuite.scala | 29 ++--- 6 files changed, 151 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 15a856b273ed..d57464fcefc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -43,9 +43,6 @@ object UnsupportedOperationChecker extends Logging { throwError("dropDuplicatesWithinWatermark is not supported with batch " + "DataFrames/DataSets")(d) - case t: TransformWithState => - throwError("transformWithState is not supported with batch DataFrames/Datasets")(t) - case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f5c2f17f8826..65347fc9d237 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -723,7 +723,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Strategy to convert [[TransformWithState]] logical operator to physical operator * in streaming plans. 
*/ - object TransformWithStateStrategy extends Strategy { + object StreamingTransformWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case TransformWithState( keyDeserializer, valueDeserializer, groupingAttributes, @@ -892,6 +892,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, hasInitialState, planLater(initialState), planLater(child) ) :: Nil + case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, + dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder, + outputObjAttr, child) => + TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer, + groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, + keyEncoder, outputObjAttr, planLater(child)) :: Nil + case _: FlatMapGroupsInPandasWithState => // TODO(SPARK-40443): support applyInPandasWithState in batch query throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3176") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 08d41b840d04..4469d52618e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -73,7 +73,7 @@ class IncrementalExecution( StreamingRelationStrategy :: StreamingDeduplicationStrategy :: StreamingGlobalLimitStrategy(outputMode) :: - TransformWithStateStrategy :: Nil + StreamingTransformWithStateStrategy :: Nil } private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index d06938ffeafb..fed18fc7e458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -69,29 +69,28 @@ class QueryInfoImpl( * @param store - instance of state store * @param runId - unique id for the current run * @param keyEncoder - encoder for the key + * @param isStreaming - defines whether the query is streaming or batch */ class StatefulProcessorHandleImpl( store: StateStore, runId: UUID, - keyEncoder: ExpressionEncoder[Any]) + keyEncoder: ExpressionEncoder[Any], + isStreaming: Boolean = true) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" private def buildQueryInfo(): QueryInfo = { - val taskCtxOpt = Option(TaskContext.get()) - // Task context is not available in tests, so we generate a random query id and batch id here - val queryId = if (taskCtxOpt.isDefined) { - taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY) - } else { - assert(Utils.isTesting, "Failed to find query id in task context") - UUID.randomUUID().toString - } - val batchId = if (taskCtxOpt.isDefined) { - taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong + val taskCtxOpt = Option(TaskContext.get()) + val (queryId, batchId) = if (!isStreaming) { + (BATCH_QUERY_ID, 0L) + } else if (taskCtxOpt.isDefined) { + (taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY), + taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong) } else { - assert(Utils.isTesting, "Failed to find batch id in task context") - 0 + assert(Utils.isTesting, "Failed to find query id/batch Id in task context") + (UUID.randomUUID().toString, 0L) } new QueryInfoImpl(UUID.fromString(queryId), runId, 
batchId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 82e827685b47..818bef5f34a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD @@ -25,9 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode} import org.apache.spark.sql.types._ -import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} /** * Physical operator for executing `TransformWithState` @@ -44,6 +46,7 @@ import org.apache.spark.util.CompletionIterator * @param batchTimestampMs processing timestamp of the current batch. 
* @param eventTimeWatermarkForLateEvents event time watermark for filtering late events * @param eventTimeWatermarkForEviction event time watermark for state eviction + * @param isStreaming defines whether the query is streaming or batch * @param child the physical plan for the underlying data */ case class TransformWithStateExec( @@ -60,7 +63,8 @@ case class TransformWithStateExec( batchTimestampMs: Option[Long], eventTimeWatermarkForLateEvents: Option[Long], eventTimeWatermarkForEviction: Option[Long], - child: SparkPlan) + child: SparkPlan, + isStreaming: Boolean = true) extends UnaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec { override def shortName: String = "transformWithStateExec" @@ -143,7 +147,11 @@ case class TransformWithStateExec( // by the upstream (consumer) operators in addition to the processing in this operator. allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) commitTimeMs += timeTakenMs { - store.commit() + if (isStreaming) { + store.commit() + } else { + store.abort() + } } setStoreMetrics(store) setOperatorMetrics() @@ -155,23 +163,113 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateInfo, - schemaForKeyRow, - schemaForValueRow, - numColsPrefixKey = 0, - session.sqlContext.sessionState, - Some(session.sqlContext.streams.stateStoreCoordinator), - useColumnFamilies = true - ) { - case (store: StateStore, singleIterator: Iterator[InternalRow]) => - val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - keyEncoder) - assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) - statefulProcessor.init(processorHandle, outputMode) - processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) - val result = processDataWithPartition(singleIterator, store, processorHandle) - 
result + if (isStreaming) { + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateInfo, + schemaForKeyRow, + schemaForValueRow, + numColsPrefixKey = 0, + session.sqlContext.sessionState, + Some(session.sqlContext.streams.stateStoreCoordinator), + useColumnFamilies = true + ) { + case (store: StateStore, singleIterator: Iterator[InternalRow]) => + processData(store, singleIterator) + } + } else { + // If the query is running in batch mode, we need to create a new StateStore and instantiate + // a temp directory on the executors in mapPartitionsWithIndex. + val broadcastedHadoopConf = + new SerializableConfiguration(session.sessionState.newHadoopConf()) + child.execute().mapPartitionsWithIndex[InternalRow]( + (i, iter) => { + val providerId = { + val tempDirPath = Utils.createTempDir().getAbsolutePath + new StateStoreProviderId( + StateStoreId(tempDirPath, 0, i), getStateInfo.queryRunId) + } + + val sqlConf = new SQLConf() + sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + classOf[RocksDBStateStoreProvider].getName) + val storeConf = new StateStoreConf(sqlConf) + + // Create StateStoreProvider for this partition + val stateStoreProvider = StateStoreProvider.createAndInit( + providerId, + schemaForKeyRow, + schemaForValueRow, + numColsPrefixKey = 0, + useColumnFamilies = true, + storeConf = storeConf, + hadoopConf = broadcastedHadoopConf.value) + + val store = stateStoreProvider.getStore(0) + val outputIterator = processData(store, iter) + CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator.iterator, { + stateStoreProvider.close() + statefulProcessor.close() + }) + } + ) } } + + /** + * Process the data in the partition using the state store and the stateful processor. 
+ * @param store The state store to use + * @param singleIterator The iterator of rows to process + * @return An iterator of rows that are the result of processing the input rows + */ + private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): + CompletionIterator[InternalRow, Iterator[InternalRow]] = { + val processorHandle = new StatefulProcessorHandleImpl( + store, getStateInfo.queryRunId, keyEncoder, isStreaming) + assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) + statefulProcessor.init(processorHandle, outputMode) + processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) + processDataWithPartition(singleIterator, store, processorHandle) + } +} + +object TransformWithStateExec { + + // Plan logical transformWithState for batch queries + def generateSparkPlanForBatchQueries( + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + statefulProcessor: StatefulProcessor[Any, Any, Any], + timeoutMode: TimeoutMode, + outputMode: OutputMode, + keyEncoder: ExpressionEncoder[Any], + outputObjAttr: Attribute, + child: SparkPlan): SparkPlan = { + val shufflePartitions = child.session.sessionState.conf.numShufflePartitions + val statefulOperatorStateInfo = StatefulOperatorStateInfo( + checkpointLocation = "", // empty checkpointLocation will be populated in doExecute + queryRunId = UUID.randomUUID(), + operatorId = 0, + storeVersion = 0, + numPartitions = shufflePartitions + ) + + new TransformWithStateExec( + keyDeserializer, + valueDeserializer, + groupingAttributes, + dataAttributes, + statefulProcessor, + timeoutMode, + outputMode, + keyEncoder, + outputObjAttr, + Some(statefulOperatorStateInfo), + Some(System.currentTimeMillis), + None, + None, + child, + isStreaming = false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 569e6852315c..7b448ac93419 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StateStoreMultipleColumnFamiliesNotSupportedException} import org.apache.spark.sql.internal.SQLConf @@ -196,6 +195,18 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("transformWithState - batch should succeed") { + val inputData = Seq("a", "b") + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeoutMode.NoTimeouts(), + OutputMode.Append()) + + val df = result.toDF() + checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF()) + } + test("transformWithState - test deleteIfExists operator") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -333,22 +344,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest class TransformWithStateValidationSuite extends StateStoreMetricsTest { import testImplicits._ - test("transformWithState - batch should fail") { - val ex = intercept[Exception] { - val df = Seq("a", "a", "b").toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor, - TimeoutMode.NoTimeouts(), - OutputMode.Append()) - .write - .format("noop") - .mode(SaveMode.Append) - .save() - } - assert(ex.isInstanceOf[AnalysisException]) - assert(ex.getMessage.contains("not supported")) - } - test("transformWithState - streaming with 
hdfsStateStoreProvider should fail") { val inputData = MemoryStream[String] val result = inputData.toDS() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org