This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4771853c9bc [SPARK-44398][CONNECT] Scala foreachBatch API
4771853c9bc is described below

commit 4771853c9bc26b8741091d63d77c4b6487e74189
Author: Raghu Angadi <raghu.ang...@databricks.com>
AuthorDate: Thu Jul 13 10:47:49 2023 -0700

    [SPARK-44398][CONNECT] Scala foreachBatch API
    
    This implements Scala foreachBatch(). The implementation is basic and needs
    some more enhancements. The server side will be shared by the Python
    implementation as well.
    
    One notable hack in this PR is that it runs the user's `foreachBatch()` with
    a regular (legacy) DataFrame, rather than setting up a remote Spark Connect
    session and a Connect DataFrame.
    
    ### Why are the changes needed?
    Adds foreachBatch() support in Scala Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Adds the foreachBatch() API.
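    
    For illustration, a minimal sketch of how the new API is expected to be used
    from a Connect client (the rate source and option values here are only
    assumptions for the example):
    
        val query = spark.readStream
          .format("rate")
          .option("rowsPerSecond", "10")
          .load()
          .writeStream
          .foreachBatch { (df: DataFrame, batchId: Long) =>
            // Runs on the server for every micro-batch; note the caveat above
            // about the function currently receiving a legacy DataFrame.
            println(s"Batch $batchId contains ${df.count()} rows")
          }
          .start()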
    
    ### How was this patch tested?
    - A simple unit test.
    
    Closes #41969 from rangadi/feb-scala.
    
    Authored-by: Raghu Angadi <raghu.ang...@databricks.com>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 .../spark/sql/streaming/DataStreamWriter.scala     | 28 ++++++-
 .../spark/sql/streaming/StreamingQuerySuite.scala  | 52 ++++++++++++-
 .../src/main/protobuf/spark/connect/commands.proto | 11 +--
 .../sql/connect/planner/SparkConnectPlanner.scala  | 25 +++++-
 .../planner/StreamingForeachBatchHelper.scala      | 69 +++++++++++++++++
 python/pyspark/sql/connect/proto/commands_pb2.py   | 88 +++++++++++-----------
 python/pyspark/sql/connect/proto/commands_pb2.pyi  | 46 +++++++----
 python/pyspark/sql/connect/streaming/readwriter.py |  4 +-
 8 files changed, 251 insertions(+), 72 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 9f63f68a000..ad76ab4a1bc 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -30,12 +30,15 @@ import org.apache.spark.connect.proto.Command
 import org.apache.spark.connect.proto.WriteStreamOperationStart
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, ForeachWriter}
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.common.ForeachWriterPacket
 import org.apache.spark.sql.execution.streaming.AvailableNowTrigger
 import org.apache.spark.sql.execution.streaming.ContinuousTrigger
 import org.apache.spark.sql.execution.streaming.OneTimeTrigger
 import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
+import org.apache.spark.sql.types.NullType
 import org.apache.spark.util.SparkSerDeUtils
+import org.apache.spark.util.Utils
 
 /**
  * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -218,7 +221,30 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
     val scalaWriterBuilder = proto.ScalarScalaUDF
       .newBuilder()
       .setPayload(ByteString.copyFrom(serialized))
-    sinkBuilder.getForeachWriterBuilder.setScalaWriter(scalaWriterBuilder)
+    sinkBuilder.getForeachWriterBuilder.setScalaFunction(scalaWriterBuilder)
+    this
+  }
+
+  /**
+   * :: Experimental ::
+   *
+   * (Scala-specific) Sets the output of the streaming query to be processed using the provided
+   * function. This is supported only in the micro-batch execution modes (that is, when the
+   * trigger is not continuous). In every micro-batch, the provided function will be called with
+   * (i) the output rows as a Dataset and (ii) the batch identifier. The batchId can be used to
+   * deduplicate and transactionally write the output (that is, the provided Dataset) to external
+   * systems. The output Dataset is guaranteed to be exactly the same for the same batchId
+   * (assuming all operations are deterministic in the query).
+   *
+   * @since 3.5.0
+   */
+  @Evolving
+  def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
+    val serializedFn = Utils.serialize(function)
+    sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder
+      .setPayload(ByteString.copyFrom(serializedFn))
+      .setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused.
+      .setNullable(true) // Unused.
     this
   }
 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 6ddcedf19cb..438e6e0c2fe 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -26,13 +26,14 @@ import org.scalatest.concurrent.Eventually.eventually
 import org.scalatest.concurrent.Futures.timeout
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.sql.{ForeachWriter, Row, SparkSession, SQLHelper}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession, SQLHelper}
 import org.apache.spark.sql.connect.client.util.RemoteSparkSession
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.functions.window
 import org.apache.spark.util.Utils
 
-class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
+class StreamingQuerySuite extends RemoteSparkSession with SQLHelper with Logging {
 
   test("Streaming API with windowed aggregate query") {
     // This verifies standard streaming API by starting a streaming query with windowed count.
@@ -114,7 +115,7 @@ class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
     withSQLConf(
       "spark.sql.shuffle.partitions" -> "1" // Avoid too many reducers.
     ) {
-      spark.sql("DROP TABLE IF EXISTS my_table")
+      spark.sql("DROP TABLE IF EXISTS my_table").collect()
 
       withTempPath { ckpt =>
         val q1 = spark.readStream
@@ -266,6 +267,42 @@ class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
     q.stop()
     assert(!q1.isActive)
   }
+
+  test("foreachBatch") {
+    // Starts a streaming query with a foreachBatch function, which writes batchId and row count
+    // to a temp view. The test verifies that the view is populated with data.
+
+    val viewName = "test_view"
+    val tableName = s"global_temp.$viewName"
+
+    withTable(tableName) {
+      val q = spark.readStream
+        .format("rate")
+        .option("rowsPerSecond", "10")
+        .option("numPartitions", "1")
+        .load()
+        .writeStream
+        .foreachBatch(new ForeachBatchFn(viewName))
+        .start()
+
+      eventually(timeout(30.seconds)) { // Wait for first progress.
+        assert(q.lastProgress != null)
+        assert(q.lastProgress.numInputRows > 0)
+      }
+
+      eventually(timeout(30.seconds)) {
+        // There should be row(s) in the temporary view created by foreachBatch.
+        val rows = spark
+          .sql(s"select * from $tableName")
+          .collect()
+          .toSeq
+        assert(rows.size > 0)
+        log.info(s"Rows in $tableName: $rows")
+      }
+
+      q.stop()
+    }
+  }
 }
 
 class TestForeachWriter[T] extends ForeachWriter[T] {
@@ -292,3 +329,12 @@ class TestForeachWriter[T] extends ForeachWriter[T] {
 case class TestClass(value: Int) {
   override def toString: String = value.toString
 }
+
+class ForeachBatchFn(val viewName: String) extends ((DataFrame, Long) => Unit) with Serializable {
+  override def apply(df: DataFrame, batchId: Long): Unit = {
+    val count = df.count()
+    df.sparkSession
+      .createDataFrame(Seq((batchId, count)))
+      .createOrReplaceGlobalTempView(viewName)
+  }
+}
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index ac97c2dc842..a5924747af9 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -214,13 +214,14 @@ message WriteStreamOperationStart {
     string table_name = 12;
   }
 
-  StreamingForeachWriter foreach_writer = 13;
+  StreamingForeachFunction foreach_writer = 13;
+  StreamingForeachFunction foreach_batch = 14;
 }
 
-message StreamingForeachWriter {
-  oneof writer {
-    PythonUDF python_writer = 1;
-    ScalarScalaUDF scala_writer = 2;
+message StreamingForeachFunction {
+  oneof function {
+    PythonUDF python_function = 1;
+    ScalarScalaUDF scala_function = 2;
   }
 }
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 27b668bef6b..6b1f84ada51 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -33,6 +33,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
+import org.apache.spark.connect.proto.StreamingForeachFunction
 import org.apache.spark.connect.proto.StreamingQueryManagerCommand
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -2661,13 +2662,13 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     }
 
     if (writeOp.hasForeachWriter) {
-      if (writeOp.getForeachWriter.hasPythonWriter) {
-        val foreach = writeOp.getForeachWriter.getPythonWriter
+      if (writeOp.getForeachWriter.hasPythonFunction) {
+        val foreach = writeOp.getForeachWriter.getPythonFunction
         val pythonFcn = transformPythonFunction(foreach)
         writer.foreachImplementation(
           new PythonForeachWriter(pythonFcn, dataset.schema).asInstanceOf[ForeachWriter[Any]])
       } else {
-        val foreachWriterPkt = unpackForeachWriter(writeOp.getForeachWriter.getScalaWriter)
+        val foreachWriterPkt = unpackForeachWriter(writeOp.getForeachWriter.getScalaFunction)
         val clientWriter = foreachWriterPkt.foreachWriter
         val encoder: Option[ExpressionEncoder[Any]] = Try(
           ExpressionEncoder(
@@ -2676,6 +2677,24 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       }
     }
 
+    if (writeOp.hasForeachBatch) {
+      val foreachBatchFn = writeOp.getForeachBatch.getFunctionCase match {
+        case StreamingForeachFunction.FunctionCase.PYTHON_FUNCTION =>
+          throw InvalidPlanInput("Python ForeachBatch is not supported yet. 
WIP.")
+
+        case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION =>
+          val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType](
+            writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray,
+            Utils.getContextOrSparkClassLoader)
+          StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)
+
+        case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
+          throw InvalidPlanInput("Unexpected")
+      }
+
+      writer.foreachBatch(foreachBatchFn)
+    }
+
     val query = writeOp.getPath match {
       case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName)
       case "" => writer.start()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
new file mode 100644
index 00000000000..66487e7048c
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.planner
+
+import java.util.UUID
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.connect.service.SessionHolder
+
+/**
+ * A helper object for handling ForeachBatch related functionality in Spark Connect servers.
+ */
+object StreamingForeachBatchHelper extends Logging {
+
+  type ForeachBatchFnType = (DataFrame, Long) => Unit
+
+  /**
+   * Returns a new ForeachBatch function that wraps `fn`. It sets up the DataFrame cache so that
+   * the user function can access it. The cache is cleared once ForeachBatch returns.
+   */
+  def dataFrameCachingWrapper(
+      fn: ForeachBatchFnType,
+      sessionHolder: SessionHolder): ForeachBatchFnType = { (df: DataFrame, batchId: Long) =>
+    {
+      val dfId = UUID.randomUUID().toString
+      log.info(s"Caching DataFrame with id $dfId") // TODO: Add query id to 
the log.
+
+      // TODO: Sanity check that there is no other active DataFrame for this query. Need to make
+      //       the query id available in the cache for this check.
+
+      sessionHolder.cacheDataFrameById(dfId, df)
+      try {
+        fn(df, batchId)
+      } finally {
+        log.info(s"Removing DataFrame with id $dfId from the cache")
+        sessionHolder.removeCachedDataFrame(dfId)
+      }
+    }
+  }
+
+  /**
+   * Handles setting up the Scala remote session and other Spark Connect environment, and then
+   * runs the provided foreachBatch function `fn`.
+   *
+   * HACK ALERT: This version does not actually set up Spark Connect. It directly passes the
+   * DataFrame, so the user code actually runs with a legacy DataFrame.
+   */
+  def scalaForeachBatchWrapper(
+      fn: ForeachBatchFnType,
+      sessionHolder: SessionHolder): ForeachBatchFnType = {
+    // TODO: Set up Spark Connect session. Do we actually need this for the first version?
+    dataFrameCachingWrapper(fn, sessionHolder)
+  }
+}
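
The caching wrapper above follows a register/clean-up pattern: publish the DataFrame under a fresh id for the duration of the user callback, and always remove it in `finally`. A minimal standalone sketch of the same pattern, with a hypothetical `ResourceCache` standing in for the session's DataFrame cache (not code from this patch):

    import java.util.UUID
    import scala.collection.concurrent.TrieMap

    // Hypothetical cache standing in for SessionHolder's DataFrame cache (not Spark code).
    object ResourceCache {
      private val entries = TrieMap.empty[String, Any]
      def put(id: String, value: Any): Unit = entries.update(id, value)
      def remove(id: String): Unit = { entries.remove(id); () }
    }

    // Register the callback's argument under a fresh id for the duration of the
    // call and always remove it afterwards, even if the callback throws.
    def cachingWrapper[T](fn: (T, Long) => Unit): (T, Long) => Unit = { (value, batchId) =>
      val id = UUID.randomUUID().toString
      ResourceCache.put(id, value)
      try {
        fn(value, batchId)
      } finally {
        ResourceCache.remove(id)
      }
    }
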
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index af5fdd247ca..6d3cc22dcfa 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x86\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01
 
\x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02
 
\x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x
 [...]
+    
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x86\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01
 
\x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02
 
\x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -87,49 +87,49 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _WRITEOPERATIONV2_MODE._serialized_start = 3164
     _WRITEOPERATIONV2_MODE._serialized_end = 3323
     _WRITESTREAMOPERATIONSTART._serialized_start = 3339
-    _WRITESTREAMOPERATIONSTART._serialized_end = 4059
+    _WRITESTREAMOPERATIONSTART._serialized_end = 4139
     _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_start = 1944
     _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_end = 2002
-    _STREAMINGFOREACHWRITER._serialized_start = 4062
-    _STREAMINGFOREACHWRITER._serialized_end = 4229
-    _WRITESTREAMOPERATIONSTARTRESULT._serialized_start = 4231
-    _WRITESTREAMOPERATIONSTARTRESULT._serialized_end = 4352
-    _STREAMINGQUERYINSTANCEID._serialized_start = 4354
-    _STREAMINGQUERYINSTANCEID._serialized_end = 4419
-    _STREAMINGQUERYCOMMAND._serialized_start = 4422
-    _STREAMINGQUERYCOMMAND._serialized_end = 5054
-    _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 4921
-    _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 4965
-    _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_start = 4967
-    _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_end = 5043
-    _STREAMINGQUERYCOMMANDRESULT._serialized_start = 5057
-    _STREAMINGQUERYCOMMANDRESULT._serialized_end = 6198
-    _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 5640
-    _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 5810
-    _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 5812
-    _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 5884
-    _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 5886
-    _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 5925
-    _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_start = 5928
-    _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_end = 6125
-    _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 6127
-    _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 6183
-    _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 6201
-    _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 6551
-    _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6461
-    _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 6540
-    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 6554
-    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 7277
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 6957
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7084
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7086
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 7201
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 7203
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 7262
-    _GETRESOURCESCOMMAND._serialized_start = 7279
-    _GETRESOURCESCOMMAND._serialized_end = 7300
-    _GETRESOURCESCOMMANDRESULT._serialized_start = 7303
-    _GETRESOURCESCOMMANDRESULT._serialized_end = 7515
-    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 7419
-    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 7515
+    _STREAMINGFOREACHFUNCTION._serialized_start = 4142
+    _STREAMINGFOREACHFUNCTION._serialized_end = 4321
+    _WRITESTREAMOPERATIONSTARTRESULT._serialized_start = 4323
+    _WRITESTREAMOPERATIONSTARTRESULT._serialized_end = 4444
+    _STREAMINGQUERYINSTANCEID._serialized_start = 4446
+    _STREAMINGQUERYINSTANCEID._serialized_end = 4511
+    _STREAMINGQUERYCOMMAND._serialized_start = 4514
+    _STREAMINGQUERYCOMMAND._serialized_end = 5146
+    _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 5013
+    _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 5057
+    _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_start = 5059
+    _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_end = 5135
+    _STREAMINGQUERYCOMMANDRESULT._serialized_start = 5149
+    _STREAMINGQUERYCOMMANDRESULT._serialized_end = 6290
+    _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 5732
+    _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 5902
+    _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 5904
+    _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 5976
+    _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 5978
+    _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 6017
+    _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_start = 6020
+    _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_end = 6217
+    _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 6219
+    _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 6275
+    _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 6293
+    _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 6643
+    _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6553
+    _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 6632
+    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 6646
+    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 7369
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7049
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7176
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7178
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 7293
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 7295
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 7354
+    _GETRESOURCESCOMMAND._serialized_start = 7371
+    _GETRESOURCESCOMMAND._serialized_end = 7392
+    _GETRESOURCESCOMMANDRESULT._serialized_start = 7395
+    _GETRESOURCESCOMMANDRESULT._serialized_end = 7607
+    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 7511
+    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 7607
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 3677af39fa1..2590139d301 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -714,6 +714,7 @@ class WriteStreamOperationStart(google.protobuf.message.Message):
     PATH_FIELD_NUMBER: builtins.int
     TABLE_NAME_FIELD_NUMBER: builtins.int
     FOREACH_WRITER_FIELD_NUMBER: builtins.int
+    FOREACH_BATCH_FIELD_NUMBER: builtins.int
     @property
     def input(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
         """(Required) The output of the `input` streaming relation will be 
written."""
@@ -738,7 +739,9 @@ class WriteStreamOperationStart(google.protobuf.message.Message):
     path: builtins.str
     table_name: builtins.str
     @property
-    def foreach_writer(self) -> global___StreamingForeachWriter: ...
+    def foreach_writer(self) -> global___StreamingForeachFunction: ...
+    @property
+    def foreach_batch(self) -> global___StreamingForeachFunction: ...
     def __init__(
         self,
         *,
@@ -754,7 +757,8 @@ class WriteStreamOperationStart(google.protobuf.message.Message):
         query_name: builtins.str = ...,
         path: builtins.str = ...,
         table_name: builtins.str = ...,
-        foreach_writer: global___StreamingForeachWriter | None = ...,
+        foreach_writer: global___StreamingForeachFunction | None = ...,
+        foreach_batch: global___StreamingForeachFunction | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -763,6 +767,8 @@ class WriteStreamOperationStart(google.protobuf.message.Message):
             b"available_now",
             "continuous_checkpoint_interval",
             b"continuous_checkpoint_interval",
+            "foreach_batch",
+            b"foreach_batch",
             "foreach_writer",
             b"foreach_writer",
             "input",
@@ -788,6 +794,8 @@ class WriteStreamOperationStart(google.protobuf.message.Message):
             b"available_now",
             "continuous_checkpoint_interval",
             b"continuous_checkpoint_interval",
+            "foreach_batch",
+            b"foreach_batch",
             "foreach_writer",
             b"foreach_writer",
             "format",
@@ -829,38 +837,48 @@ class WriteStreamOperationStart(google.protobuf.message.Message):
 
 global___WriteStreamOperationStart = WriteStreamOperationStart
 
-class StreamingForeachWriter(google.protobuf.message.Message):
+class StreamingForeachFunction(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-    PYTHON_WRITER_FIELD_NUMBER: builtins.int
-    SCALA_WRITER_FIELD_NUMBER: builtins.int
+    PYTHON_FUNCTION_FIELD_NUMBER: builtins.int
+    SCALA_FUNCTION_FIELD_NUMBER: builtins.int
     @property
-    def python_writer(self) -> pyspark.sql.connect.proto.expressions_pb2.PythonUDF: ...
+    def python_function(self) -> pyspark.sql.connect.proto.expressions_pb2.PythonUDF: ...
     @property
-    def scala_writer(self) -> pyspark.sql.connect.proto.expressions_pb2.ScalarScalaUDF: ...
+    def scala_function(self) -> pyspark.sql.connect.proto.expressions_pb2.ScalarScalaUDF: ...
     def __init__(
         self,
         *,
-        python_writer: pyspark.sql.connect.proto.expressions_pb2.PythonUDF | None = ...,
-        scala_writer: pyspark.sql.connect.proto.expressions_pb2.ScalarScalaUDF | None = ...,
+        python_function: pyspark.sql.connect.proto.expressions_pb2.PythonUDF | None = ...,
+        scala_function: pyspark.sql.connect.proto.expressions_pb2.ScalarScalaUDF | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "python_writer", b"python_writer", "scala_writer", 
b"scala_writer", "writer", b"writer"
+            "function",
+            b"function",
+            "python_function",
+            b"python_function",
+            "scala_function",
+            b"scala_function",
         ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "python_writer", b"python_writer", "scala_writer", 
b"scala_writer", "writer", b"writer"
+            "function",
+            b"function",
+            "python_function",
+            b"python_function",
+            "scala_function",
+            b"scala_function",
         ],
     ) -> None: ...
     def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["writer", b"writer"]
-    ) -> typing_extensions.Literal["python_writer", "scala_writer"] | None: ...
+        self, oneof_group: typing_extensions.Literal["function", b"function"]
+    ) -> typing_extensions.Literal["python_function", "scala_function"] | 
None: ...
 
-global___StreamingForeachWriter = StreamingForeachWriter
+global___StreamingForeachFunction = StreamingForeachFunction
 
 class WriteStreamOperationStartResult(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py
index 563ce375d7a..156a3ba87db 100644
--- a/python/pyspark/sql/connect/streaming/readwriter.py
+++ b/python/pyspark/sql/connect/streaming/readwriter.py
@@ -487,10 +487,10 @@ class DataStreamWriter:
         serializer = AutoBatchedSerializer(CPickleSerializer())
         command = (func, None, serializer, serializer)
         # Python ForeachWriter isn't really a PythonUDF. But we reuse it for simplicity.
-        self._write_proto.foreach_writer.python_writer.command = CloudPickleSerializer().dumps(
+        self._write_proto.foreach_writer.python_function.command = CloudPickleSerializer().dumps(
             command
         )
-        self._write_proto.foreach_writer.python_writer.python_ver = "%d.%d" % sys.version_info[:2]
+        self._write_proto.foreach_writer.python_function.python_ver = "%d.%d" % sys.version_info[:2]
         return self
 
     foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__

