This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new becfb94e1c71 [SPARK-46865][SS] Add Batch Support for TransformWithState Operator
becfb94e1c71 is described below

commit becfb94e1c713d10dac83300d096be490a912fd2
Author: Eric Marnadi <eric.marn...@databricks.com>
AuthorDate: Thu Feb 8 12:15:20 2024 +0900

    [SPARK-46865][SS] Add Batch Support for TransformWithState Operator
    
    ### What changes were proposed in this pull request?
    
    This change allows batch queries to define and use the `TransformWithState`
    operator, which was initially introduced for streaming queries.
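
    For example, a batch query of the following shape now plans and runs instead of
    failing analysis (a minimal sketch; `RunningCountStatefulProcessor` is the example
    processor from the test suite in this patch, and a `SparkSession` with its
    implicits is assumed to be in scope):

        import org.apache.spark.sql.streaming.{OutputMode, TimeoutMode}
        import spark.implicits._

        // Batch Dataset, no streaming source involved. Before this change the
        // analyzer rejected this with "transformWithState is not supported with
        // batch DataFrames/Datasets".
        val result = Seq("a", "b", "a").toDS()
          .groupByKey(x => x)
          .transformWithState(new RunningCountStatefulProcessor(),
            TimeoutMode.NoTimeouts(),
            OutputMode.Append())

        result.toDF().show()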
    
    ### Why are the changes needed?
    
    This is needed to maintain parity between the streaming and batch APIs,
    since we want everything supported in streaming to also be supported in
    batch.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit tests that use the TransformWithState operator with a
    batch query.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44884 from ericm-db/tws-batch.
    
    Lead-authored-by: Eric Marnadi <eric.marn...@databricks.com>
    Co-authored-by: ericm-db <132308037+ericm...@users.noreply.github.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../analysis/UnsupportedOperationChecker.scala     |   3 -
 .../spark/sql/execution/SparkStrategies.scala      |   9 +-
 .../execution/streaming/IncrementalExecution.scala |   2 +-
 .../streaming/StatefulProcessorHandleImpl.scala    |  25 ++--
 .../streaming/TransformWithStateExec.scala         | 138 ++++++++++++++++++---
 .../sql/streaming/TransformWithStateSuite.scala    |  29 ++---
 6 files changed, 151 insertions(+), 55 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 15a856b273ed..d57464fcefc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -43,9 +43,6 @@ object UnsupportedOperationChecker extends Logging {
         throwError("dropDuplicatesWithinWatermark is not supported with batch 
" +
           "DataFrames/DataSets")(d)
 
-      case t: TransformWithState =>
-        throwError("transformWithState is not supported with batch 
DataFrames/Datasets")(t)
-
       case _ =>
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f5c2f17f8826..65347fc9d237 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -723,7 +723,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   * Strategy to convert [[TransformWithState]] logical operator to physical operator
    * in streaming plans.
    */
-  object TransformWithStateStrategy extends Strategy {
+  object StreamingTransformWithStateStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case TransformWithState(
         keyDeserializer, valueDeserializer, groupingAttributes,
@@ -892,6 +892,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
           hasInitialState, planLater(initialState), planLater(child)
         ) :: Nil
+      case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
+          dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder,
+          outputObjAttr, child) =>
+        TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer,
+          groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode,
+          keyEncoder, outputObjAttr, planLater(child)) :: Nil
+
       case _: FlatMapGroupsInPandasWithState =>
         // TODO(SPARK-40443): support applyInPandasWithState in batch query
         throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3176")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 08d41b840d04..4469d52618e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -73,7 +73,7 @@ class IncrementalExecution(
       StreamingRelationStrategy ::
       StreamingDeduplicationStrategy ::
       StreamingGlobalLimitStrategy(outputMode) ::
-      TransformWithStateStrategy :: Nil
+      StreamingTransformWithStateStrategy :: Nil
   }
 
   private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index d06938ffeafb..fed18fc7e458 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -69,29 +69,28 @@ class QueryInfoImpl(
  * @param store - instance of state store
  * @param runId - unique id for the current run
  * @param keyEncoder - encoder for the key
+ * @param isStreaming - defines whether the query is streaming or batch
  */
 class StatefulProcessorHandleImpl(
     store: StateStore,
     runId: UUID,
-    keyEncoder: ExpressionEncoder[Any])
+    keyEncoder: ExpressionEncoder[Any],
+    isStreaming: Boolean = true)
   extends StatefulProcessorHandle with Logging {
   import StatefulProcessorHandleState._
 
+  private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000"
   private def buildQueryInfo(): QueryInfo = {
-    val taskCtxOpt = Option(TaskContext.get())
-    // Task context is not available in tests, so we generate a random query id and batch id here
-    val queryId = if (taskCtxOpt.isDefined) {
-      taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY)
-    } else {
-      assert(Utils.isTesting, "Failed to find query id in task context")
-      UUID.randomUUID().toString
-    }
 
-    val batchId = if (taskCtxOpt.isDefined) {
-      taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong
+    val taskCtxOpt = Option(TaskContext.get())
+    val (queryId, batchId) = if (!isStreaming) {
+      (BATCH_QUERY_ID, 0L)
+    } else if (taskCtxOpt.isDefined) {
+      (taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY),
+        taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong)
     } else {
-      assert(Utils.isTesting, "Failed to find batch id in task context")
-      0
+      assert(Utils.isTesting, "Failed to find query id/batch Id in task context")
+      (UUID.randomUUID().toString, 0L)
     }
 
     new QueryInfoImpl(UUID.fromString(queryId), runId, batchId)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 82e827685b47..818bef5f34a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
 import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import org.apache.spark.rdd.RDD
@@ -25,9 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode}
 import org.apache.spark.sql.types._
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils}
 
 /**
  * Physical operator for executing `TransformWithState`
@@ -44,6 +46,7 @@ import org.apache.spark.util.CompletionIterator
  * @param batchTimestampMs processing timestamp of the current batch.
  * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
  * @param eventTimeWatermarkForEviction event time watermark for state eviction
+ * @param isStreaming defines whether the query is streaming or batch
  * @param child the physical plan for the underlying data
  */
 case class TransformWithStateExec(
@@ -60,7 +63,8 @@ case class TransformWithStateExec(
     batchTimestampMs: Option[Long],
     eventTimeWatermarkForLateEvents: Option[Long],
     eventTimeWatermarkForEviction: Option[Long],
-    child: SparkPlan)
+    child: SparkPlan,
+    isStreaming: Boolean = true)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec {
 
   override def shortName: String = "transformWithStateExec"
@@ -143,7 +147,11 @@ case class TransformWithStateExec(
       // by the upstream (consumer) operators in addition to the processing in this operator.
       allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
       commitTimeMs += timeTakenMs {
-        store.commit()
+        if (isStreaming) {
+          store.commit()
+        } else {
+          store.abort()
+        }
       }
       setStoreMetrics(store)
       setOperatorMetrics()
@@ -155,23 +163,113 @@ case class TransformWithStateExec(
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
-    child.execute().mapPartitionsWithStateStore[InternalRow](
-      getStateInfo,
-      schemaForKeyRow,
-      schemaForValueRow,
-      numColsPrefixKey = 0,
-      session.sqlContext.sessionState,
-      Some(session.sqlContext.streams.stateStoreCoordinator),
-      useColumnFamilies = true
-    ) {
-      case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
-        val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
-          keyEncoder)
-        assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
-        statefulProcessor.init(processorHandle, outputMode)
-        processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
-        val result = processDataWithPartition(singleIterator, store, processorHandle)
-        result
+    if (isStreaming) {
+      child.execute().mapPartitionsWithStateStore[InternalRow](
+        getStateInfo,
+        schemaForKeyRow,
+        schemaForValueRow,
+        numColsPrefixKey = 0,
+        session.sqlContext.sessionState,
+        Some(session.sqlContext.streams.stateStoreCoordinator),
+        useColumnFamilies = true
+      ) {
+        case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
+          processData(store, singleIterator)
+      }
+    } else {
+      // If the query is running in batch mode, we need to create a new StateStore and instantiate
+      // a temp directory on the executors in mapPartitionsWithIndex.
+      val broadcastedHadoopConf =
+        new SerializableConfiguration(session.sessionState.newHadoopConf())
+      child.execute().mapPartitionsWithIndex[InternalRow](
+        (i, iter) => {
+          val providerId = {
+            val tempDirPath = Utils.createTempDir().getAbsolutePath
+            new StateStoreProviderId(
+              StateStoreId(tempDirPath, 0, i), getStateInfo.queryRunId)
+          }
+
+          val sqlConf = new SQLConf()
+          sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+            classOf[RocksDBStateStoreProvider].getName)
+          val storeConf = new StateStoreConf(sqlConf)
+
+          // Create StateStoreProvider for this partition
+          val stateStoreProvider = StateStoreProvider.createAndInit(
+            providerId,
+            schemaForKeyRow,
+            schemaForValueRow,
+            numColsPrefixKey = 0,
+            useColumnFamilies = true,
+            storeConf = storeConf,
+            hadoopConf = broadcastedHadoopConf.value)
+
+          val store = stateStoreProvider.getStore(0)
+          val outputIterator = processData(store, iter)
+          CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator.iterator, {
+            stateStoreProvider.close()
+            statefulProcessor.close()
+          })
+        }
+      )
     }
   }
+
+  /**
+   * Process the data in the partition using the state store and the stateful processor.
+   * @param store The state store to use
+   * @param singleIterator The iterator of rows to process
+   * @return An iterator of rows that are the result of processing the input rows
+   */
+  private def processData(store: StateStore, singleIterator: Iterator[InternalRow]):
+    CompletionIterator[InternalRow, Iterator[InternalRow]] = {
+    val processorHandle = new StatefulProcessorHandleImpl(
+      store, getStateInfo.queryRunId, keyEncoder, isStreaming)
+    assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
+    statefulProcessor.init(processorHandle, outputMode)
+    processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+    processDataWithPartition(singleIterator, store, processorHandle)
+  }
+}
+
+object TransformWithStateExec {
+
+  // Plan logical transformWithState for batch queries
+  def generateSparkPlanForBatchQueries(
+      keyDeserializer: Expression,
+      valueDeserializer: Expression,
+      groupingAttributes: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
+      statefulProcessor: StatefulProcessor[Any, Any, Any],
+      timeoutMode: TimeoutMode,
+      outputMode: OutputMode,
+      keyEncoder: ExpressionEncoder[Any],
+      outputObjAttr: Attribute,
+      child: SparkPlan): SparkPlan = {
+    val shufflePartitions = child.session.sessionState.conf.numShufflePartitions
+    val statefulOperatorStateInfo = StatefulOperatorStateInfo(
+      checkpointLocation = "", // empty checkpointLocation will be populated 
in doExecute
+      queryRunId = UUID.randomUUID(),
+      operatorId = 0,
+      storeVersion = 0,
+      numPartitions = shufflePartitions
+    )
+
+    new TransformWithStateExec(
+      keyDeserializer,
+      valueDeserializer,
+      groupingAttributes,
+      dataAttributes,
+      statefulProcessor,
+      timeoutMode,
+      outputMode,
+      keyEncoder,
+      outputObjAttr,
+      Some(statefulOperatorStateInfo),
+      Some(System.currentTimeMillis),
+      None,
+      None,
+      child,
+      isStreaming = false)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 569e6852315c..7b448ac93419 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming
 
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, SaveMode}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StateStoreMultipleColumnFamiliesNotSupportedException}
 import org.apache.spark.sql.internal.SQLConf
@@ -196,6 +195,18 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
+  test("transformWithState - batch should succeed") {
+    val inputData = Seq("a", "b")
+    val result = inputData.toDS()
+      .groupByKey(x => x)
+      .transformWithState(new RunningCountStatefulProcessor(),
+        TimeoutMode.NoTimeouts(),
+        OutputMode.Append())
+
+    val df = result.toDF()
+    checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF())
+  }
+
   test("transformWithState - test deleteIfExists operator") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
@@ -333,22 +344,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest
 class TransformWithStateValidationSuite extends StateStoreMetricsTest {
   import testImplicits._
 
-  test("transformWithState - batch should fail") {
-    val ex = intercept[Exception] {
-      val df = Seq("a", "a", "b").toDS()
-        .groupByKey(x => x)
-        .transformWithState(new RunningCountStatefulProcessor,
-          TimeoutMode.NoTimeouts(),
-          OutputMode.Append())
-        .write
-        .format("noop")
-        .mode(SaveMode.Append)
-        .save()
-    }
-    assert(ex.isInstanceOf[AnalysisException])
-    assert(ex.getMessage.contains("not supported"))
-  }
-
   test("transformWithState - streaming with hdfsStateStoreProvider should 
fail") {
     val inputData = MemoryStream[String]
     val result = inputData.toDS()

