This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5ac39181fe87 [SPARK-47273][SS][PYTHON] implement Python data stream writer interface
5ac39181fe87 is described below

commit 5ac39181fe87aba4eab66ff2590bbc16349c0bab
Author: Chaoqin Li <chaoqin...@databricks.com>
AuthorDate: Wed Mar 27 12:51:36 2024 +0900

    [SPARK-47273][SS][PYTHON] implement Python data stream writer interface
    
    ### What changes were proposed in this pull request?
    Reuse PythonPartitionWriter to implement the serialization and execution of the
    write callback on executors.
    
    Implement a Python worker process that runs the Python streaming data sink committer
    and communicates with the JVM through a socket on the Spark driver. For each Python
    streaming data sink instance, a long-lived Python worker process is created. Inside
    that process, the Python write committer receives commit or abort calls and sends the
    result back through the socket.
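
    For illustration, here is a minimal sketch of a sink built on the new interface
    (class names such as `MySink` are hypothetical and not part of this change):

    ```python
    from pyspark.sql.datasource import (
        DataSource,
        DataSourceStreamWriter,
        WriterCommitMessage,
    )


    class MySinkStreamWriter(DataSourceStreamWriter):
        def write(self, iterator):
            # Runs on executors for each partition of every microbatch.
            for row in iterator:
                pass  # send the row to the external system here
            return WriterCommitMessage()

        def commit(self, messages, batchId):
            # Runs on the driver after all write tasks of a microbatch succeed.
            pass

        def abort(self, messages, batchId):
            # Runs on the driver if any write task of a microbatch fails.
            pass


    class MySink(DataSource):
        def streamWriter(self, schema, overwrite):
            return MySinkStreamWriter()
    ```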
    
    ### Why are the changes needed?
    To support developing Spark streaming sinks in Python, we need to implement a
    Python stream writer interface.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit and integration tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
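
    As a usage sketch (assuming the hypothetical `MySink` class above; the source,
    checkpoint path, and registration call are illustrative):

    ```python
    # Register the Python data source; its short name defaults to the class name.
    spark.dataSource.register(MySink)

    query = (
        spark.readStream.format("rate").load()
        .writeStream
        .format("MySink")
        .option("checkpointLocation", "/tmp/checkpoint")
        .start()
    )
    ```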
    
    Closes #45305 from chaoqin-li1123/python_stream_writer.
    
    Authored-by: Chaoqin Li <chaoqin...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/datasource.py                   |  94 ++++++++
 .../sql/worker/python_streaming_sink_runner.py     | 140 +++++++++++
 .../pyspark/sql/worker/write_into_data_source.py   |  10 +-
 .../python/PythonStreamingSinkCommitRunner.scala   | 133 +++++++++++
 .../v2/python/PythonStreamingWrite.scala           |  84 +++++++
 .../datasources/v2/python/PythonTable.scala        |   4 +-
 .../datasources/v2/python/PythonWrite.scala        |  12 +-
 .../v2/python/UserDefinedPythonDataSource.scala    |  11 +-
 .../spark/sql/streaming/DataStreamWriter.scala     |   5 +
 .../python/PythonStreamingDataSourceSuite.scala    | 261 ++++++++++++++++++++-
 10 files changed, 744 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index 803765e83093..c08b5b7af77f 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -160,6 +160,29 @@ class DataSource(ABC):
             message_parameters={"feature": "writer"},
         )
 
+    def streamWriter(self, schema: StructType, overwrite: bool) -> "DataSourceStreamWriter":
+        """
+        Returns a :class:`DataSourceStreamWriter` instance for writing data into a streaming sink.
+
+        The implementation is required for writable streaming data sources.
+
+        Parameters
+        ----------
+        schema : :class:`StructType`
+            The schema of the data to be written.
+        overwrite : bool
+            A flag indicating whether to overwrite existing data when writing the current microbatch.
+
+        Returns
+        -------
+        writer : :class:`DataSourceStreamWriter`
+            A writer instance for writing data into a streaming sink.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "streamWriter"},
+        )
+
     def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
         """
         Returns a :class:`DataSourceStreamReader` instance for reading streaming data.
@@ -513,6 +536,77 @@ class DataSourceWriter(ABC):
         ...
 
 
+class DataSourceStreamWriter(ABC):
+    """
+    A base class for data stream writers. Data stream writers are responsible for writing
+    the data to the streaming sink.
+
+    .. versionadded:: 4.0.0
+    """
+
+    @abstractmethod
+    def write(self, iterator: Iterator[Row]) -> "WriterCommitMessage":
+        """
+        Writes data into the streaming sink.
+
+        This method is called on executors to write data to the streaming data sink in
+        each microbatch. It accepts an iterator of input data and returns a single row
+        representing a commit message, or None if there is no commit message.
+
+        The driver collects commit messages, if any, from all executors and passes them
+        to the ``commit`` method if all tasks run successfully. If any task fails, the
+        ``abort`` method will be called with the collected commit messages.
+
+        Parameters
+        ----------
+        iterator : Iterator[Row]
+            An iterator of input data.
+
+        Returns
+        -------
+        WriterCommitMessage : a serializable commit message
+        """
+        ...
+
+    def commit(self, messages: List["WriterCommitMessage"], batchId: int) -> None:
+        """
+        Commits this microbatch with a list of commit messages.
+
+        This method is invoked on the driver when all tasks run successfully. The
+        commit messages are collected from the ``write`` method call from each task,
+        and are passed to this method. The implementation should use the commit messages
+        to commit the microbatch in the streaming sink.
+
+        Parameters
+        ----------
+        messages : List[WriterCommitMessage]
+            A list of commit messages.
+        batchId: int
+            An integer that uniquely identifies a batch of data being written.
+            The integer increases by 1 with each microbatch processed.
+        """
+        ...
+
+    def abort(self, messages: List["WriterCommitMessage"], batchId: int) -> None:
+        """
+        Aborts this microbatch due to task failures.
+
+        This method is invoked on the driver when one or more tasks failed. The commit
+        messages are collected from the ``write`` method call from each task, and are
+        passed to this method. The implementation should use the commit messages to
+        abort the microbatch in the streaming sink.
+
+        Parameters
+        ----------
+        messages : List[WriterCommitMessage]
+            A list of commit messages.
+        batchId: int
+            An integer that uniquely identifies a batch of data being written.
+            The integer increases by 1 with each microbatch processed.
+        """
+        ...
+
+
 class WriterCommitMessage:
     """
     A commit message returned by the :meth:`DataSourceWriter.write` and will be
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py
new file mode 100644
index 000000000000..d4f81da5aceb
--- /dev/null
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -0,0 +1,140 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+from typing import IO
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_bool,
+    read_int,
+    read_long,
+    write_int,
+    SpecialLengths,
+)
+from pyspark.sql.datasource import DataSource, WriterCommitMessage
+from pyspark.sql.types import (
+    _parse_datatype_json_string,
+    StructType,
+)
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_memory_limits,
+    setup_spark_files,
+    utf8_deserializer,
+)
+
+
+def main(infile: IO, outfile: IO) -> None:
+    try:
+        check_python_version(infile)
+        setup_spark_files(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
+        setup_memory_limits(memory_limit_mb)
+
+        _accumulatorRegistry.clear()
+
+        # Receive the data source instance.
+        data_source = read_command(pickleSer, infile)
+
+        if not isinstance(data_source, DataSource):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "a Python data source instance of type 
'DataSource'",
+                    "actual": f"'{type(data_source).__name__}'",
+                },
+            )
+        # Receive the data source output schema.
+        schema_json = utf8_deserializer.loads(infile)
+        schema = _parse_datatype_json_string(schema_json)
+        if not isinstance(schema, StructType):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "an output schema of type 'StructType'",
+                    "actual": f"'{type(schema).__name__}'",
+                },
+            )
+        # Receive the `overwrite` flag.
+        overwrite = read_bool(infile)
+        # Instantiate the data source stream writer.
+        try:
+            writer = data_source.streamWriter(schema=schema, overwrite=overwrite)
+            # Initialization succeeded.
+            write_int(0, outfile)
+            outfile.flush()
+
+            # Handle commit/abort method calls received over the socket.
+            while True:
+                num_messages = read_int(infile)
+                commit_messages = []
+                for _ in range(num_messages):
+                    message = pickleSer._read_with_length(infile)
+                    if message is not None and not isinstance(message, WriterCommitMessage):
+                        raise PySparkAssertionError(
+                            error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                            message_parameters={
+                                "expected": "an instance of 
WriterCommitMessage",
+                                "actual": f"'{type(message).__name__}'",
+                            },
+                        )
+                    commit_messages.append(message)
+                batch_id = read_long(infile)
+                abort = read_bool(infile)
+                # Commit or abort the Python data source write.
+                # Note the commit messages can be None if there are failed tasks.
+                if abort:
+                    writer.abort(commit_messages, batch_id)  # type: ignore[arg-type]
+                else:
+                    writer.commit(commit_messages, batch_id)  # type: ignore[arg-type]
+                write_int(0, outfile)
+                outfile.flush()
+        except Exception as e:
+            error_msg = "data source {} throw exception: {}".format(data_source.name, e)
+            raise PySparkRuntimeError(
+                error_class="PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
+                message_parameters={"action": "commitOrAbort", "error": error_msg},
+            )
+            )
+    except BaseException as e:
+        handle_worker_exception(e, outfile)
+        sys.exit(-1)
+    send_accumulator_updates(outfile)
+
+    # check end of stream
+    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+        write_int(SpecialLengths.END_OF_STREAM, outfile)
+    else:
+        # write a different value to tell JVM to not reuse this worker
+        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    # Read information about how to connect back to the JVM from the environment.
+    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 0ba6fc6eb17f..490ede9ab0f2 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -152,11 +152,17 @@ def main(infile: IO, outfile: IO) -> None:
         # Receive the `overwrite` flag.
         overwrite = read_bool(infile)
 
+        is_streaming = read_bool(infile)
+
         # Instantiate a data source.
         data_source = data_source_cls(options=options)  # type: ignore
 
-        # Instantiate the data source writer.
-        writer = data_source.writer(schema, overwrite)
+        if is_streaming:
+            # Instantiate the streaming data source writer.
+            writer = data_source.streamWriter(schema, overwrite)
+        else:
+            # Instantiate the data source writer.
+            writer = data_source.writer(schema, overwrite)  # type: ignore[assignment]
 
         # Create a function that can be used in mapInArrow.
         import pyarrow as pa
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala
new file mode 100644
index 000000000000..a444fdfff7d9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
+import org.apache.spark.sql.connector.write.WriterCommitMessage
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types.StructType
+
+/**
+ * This class is a proxy to invoke commit or abort methods in a Python DataSourceStreamWriter.
+ * The runner spawns a python worker process and, in its main function, sets up communication
+ * between the JVM and the python process through a socket and creates a DataSourceStreamWriter
+ * instance. In an infinite loop, the python worker process receives write commit messages
+ * from the socket, then commits or aborts the microbatch.
+ */
+class PythonStreamingSinkCommitRunner(
+    func: PythonFunction,
+    schema: StructType,
+    overwrite: Boolean) extends Logging {
+  val workerModule: String = "pyspark.sql.worker.python_streaming_sink_runner"
+
+  private val conf = SparkEnv.get.conf
+  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+
+  private val envVars: java.util.Map[String, String] = func.envVars
+  private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[PythonWorker] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
+  protected val pythonVer: String = func.pythonVer
+
+  private var dataOut: DataOutputStream = null
+  private var dataIn: DataInputStream = null
+
+  /**
+   * Initializes the Python worker for running the streaming sink committer.
+   */
+  def init(): Unit = {
+    logInfo(s"Initializing Python runner pythonExec: $pythonExec")
+    val env = SparkEnv.get
+
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    val workerFactory =
+      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap, false)
+    val (worker: PythonWorker, _) = workerFactory.createSimpleWorker(blockingMode = true)
+    pythonWorker = Some(worker)
+    pythonWorkerFactory = Some(workerFactory)
+
+    val stream = new BufferedOutputStream(
+      pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+    dataOut = new DataOutputStream(stream)
+
+    PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+
+    val pythonIncludes = func.pythonIncludes.asScala.toSet
+    PythonWorkerUtils.writeSparkFiles(Some("streaming_job"), pythonIncludes, dataOut)
+
+    // Send the user function to python process
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    PythonWorkerUtils.writeUTF(schema.json, dataOut)
+
+    dataOut.writeBoolean(overwrite)
+
+    dataOut.flush()
+
+    dataIn = new DataInputStream(
+      new BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, bufferSize))
+
+    val initStatus = dataIn.readInt()
+    if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "initialize streaming sink", msg)
+    }
+  }
+
+  init()
+
+  def commitOrAbort(
+      messages: Array[WriterCommitMessage],
+      batchId: Long,
+      abort: Boolean): Unit = {
+    dataOut.writeInt(messages.length)
+    messages.foreach { message =>
+      // Commit messages can be null if there are task failures.
+      if (message == null) {
+        dataOut.writeInt(SpecialLengths.NULL)
+      } else {
+        PythonWorkerUtils.writeBytes(
+          message.asInstanceOf[PythonWriterCommitMessage].pickledMessage, dataOut)
+      }
+    }
+    dataOut.writeLong(batchId)
+    dataOut.writeBoolean(abort)
+    dataOut.flush()
+    val status = dataIn.readInt()
+    if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      val action = if (abort) "abort" else "commit"
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(action, msg)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala
new file mode 100644
index 000000000000..483fd5a4e0a1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.JobArtifactSet
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[StreamingWrite]] for python data source writing.
+ * Responsible for generating the writer factory and committing or aborting a microbatch.
+ * */
+class PythonStreamingWrite(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    info: LogicalWriteInfo,
+    isTruncate: Boolean) extends StreamingWrite {
+
+  // Store the pickled data source writer instance.
+  private var pythonDataSourceWriter: Array[Byte] = _
+
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+  private def createDataSourceFunc =
+    ds.source.createPythonFunction(
+      ds.getOrCreateDataSourceInPython(shortName, info.options(), Some(info.schema())).dataSource
+    )
+
+  private lazy val pythonStreamingSinkCommitRunner =
+    new PythonStreamingSinkCommitRunner(createDataSourceFunc, info.schema(), isTruncate)
+
+  override def createStreamingWriterFactory(
+       physicalInfo: PhysicalWriteInfo): StreamingDataWriterFactory = {
+    val writeInfo = ds.source.createWriteInfoInPython(
+      shortName,
+      info.schema(),
+      info.options(),
+      isTruncate,
+      isStreaming = true)
+
+    pythonDataSourceWriter = writeInfo.writer
+
+    new PythonStreamingWriterFactory(ds.source, writeInfo.func, info.schema(), jobArtifactUUID)
+  }
+
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+    pythonStreamingSinkCommitRunner.commitOrAbort(messages, epochId, false)
+  }
+
+  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+    pythonStreamingSinkCommitRunner.commitOrAbort(messages, epochId, true)
+  }
+}
+
+class PythonStreamingWriterFactory(
+    source: UserDefinedPythonDataSource,
+    pickledWriteFunc: Array[Byte],
+    inputSchema: StructType,
+    jobArtifactUUID: Option[String])
+  extends PythonBatchWriterFactory(source, pickledWriteFunc, inputSchema, jobArtifactUUID)
+    with StreamingDataWriterFactory {
+  override def createWriter(
+      partitionId: Int,
+      taskId: Long,
+      epochId: Long): DataWriter[InternalRow] = {
+    createWriter(partitionId, taskId)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
index 6bea97795a35..0476650a60bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
-import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE}
+import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE}
 import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
 import org.apache.spark.sql.types.StructType
@@ -32,7 +32,7 @@ class PythonTable(
   override def name(): String = shortName
 
   override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(
-    BATCH_READ, BATCH_WRITE, TRUNCATE)
+    BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE)
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
     new PythonScanBuilder(ds, shortName, outputSchema, options)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
index a10a18e43f64..447221715264 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
@@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.JobArtifactSet
 import org.apache.spark.sql.connector.metric.CustomMetric
-import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.connector.write.{BatchWrite, _}
+import org.apache.spark.sql.connector.write.streaming.StreamingWrite
 
 
 class PythonWrite(
@@ -32,12 +33,18 @@ class PythonWrite(
 
   override def toBatch: BatchWrite = new PythonBatchWrite(ds, shortName, info, isTruncate)
 
+  override def toStreaming: StreamingWrite =
+    new PythonStreamingWrite(ds, shortName, info, isTruncate)
+
   override def description: String = "(Python)"
 
   override def supportedCustomMetrics(): Array[CustomMetric] =
     ds.source.createPythonMetrics()
 }
 
+/**
+ * A [[BatchWrite]] for python data source writing. Responsible for generating the writer factory.
+ * */
 class PythonBatchWrite(
     ds: PythonDataSourceV2,
     shortName: String,
@@ -56,7 +63,8 @@ class PythonBatchWrite(
       shortName,
       info.schema(),
       info.options(),
-      isTruncate)
+      isTruncate,
+      isStreaming = false)
 
     pythonDataSourceWriter = writeInfo.writer
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 9b8219c4dc2d..0586d1fd4bc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -81,13 +81,15 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       provider: String,
       inputSchema: StructType,
       options: CaseInsensitiveStringMap,
-      overwrite: Boolean): PythonDataSourceWriteInfo = {
+      overwrite: Boolean,
+      isStreaming: Boolean): PythonDataSourceWriteInfo = {
     new UserDefinedPythonDataSourceWriteRunner(
       dataSourceCls,
       provider,
       inputSchema,
       options.asCaseSensitiveMap().asScala.toMap,
-      overwrite).runInPython()
+      overwrite,
+      isStreaming).runInPython()
   }
 
   /**
@@ -369,7 +371,8 @@ private class UserDefinedPythonDataSourceWriteRunner(
     provider: String,
     inputSchema: StructType,
     options: Map[String, String],
-    overwrite: Boolean) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) {
+    overwrite: Boolean,
+    isStreaming: Boolean) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) {
 
   override val workerModule: String = "pyspark.sql.worker.write_into_data_source"
 
@@ -395,6 +398,8 @@ private class UserDefinedPythonDataSourceWriteRunner(
 
     // Send the `overwrite` flag
     dataOut.writeBoolean(overwrite)
+
+    dataOut.writeBoolean(isStreaming)
   }
 
   override protected def receiveFromPython(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index be25cd0dda3b..1db03c5d816f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2}
+import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -395,6 +396,10 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         } else {
           None
         }
+        provider match {
+          case p: PythonDataSourceV2 => p.setShortName(source)
+          case _ =>
+        }
         val table = DataSourceV2Utils.getTableFromProvider(
           provider, dsOptions, userSpecifiedSchema = outputSchema)
         import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
index f022e353edd7..42eaa492be73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
@@ -16,15 +16,24 @@
  */
 package org.apache.spark.sql.execution.python
 
+import java.io.File
+
+import scala.concurrent.duration._
+
 import org.apache.spark.SparkException
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, shouldTestPandasUDFs}
 import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonStreamingSourceOffset}
+import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
+  import testImplicits._
+
+  val waitTimeout = 15.seconds
+
   protected def simpleDataStreamReaderScript: String =
     """
       |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
@@ -65,6 +74,40 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
       |        yield (2, partition.value)
       |""".stripMargin
 
+  protected def simpleDataStreamWriterScript: String =
+    s"""
+       |import json
+       |import uuid
+       |import os
+       |from pyspark import TaskContext
+       |from pyspark.sql.datasource import DataSource, DataSourceStreamWriter
+       |from pyspark.sql.datasource import WriterCommitMessage
+       |
+       |class SimpleDataSourceStreamWriter(DataSourceStreamWriter):
+       |    def __init__(self, options, overwrite):
+       |        self.options = options
+       |        self.overwrite = overwrite
+       |
+       |    def write(self, iterator):
+       |        context = TaskContext.get()
+       |        partition_id = context.partitionId()
+       |        path = self.options.get("path")
+       |        assert path is not None
+       |        output_path = os.path.join(path, f"{partition_id}.json")
+       |        cnt = 0
+       |        mode = "w" if self.overwrite else "a"
+       |        with open(output_path, mode) as file:
+       |            for row in iterator:
+       |                file.write(json.dumps(row.asDict()) + "\\n")
+       |        return WriterCommitMessage()
+       |
+       |class SimpleDataSource(DataSource):
+       |    def schema(self) -> str:
+       |        return "id INT"
+       |    def streamWriter(self, schema, overwrite):
+       |        return SimpleDataSourceStreamWriter(self.options, overwrite)
+       |""".stripMargin
+
   private val errorDataSourceName = "ErrorDataSource"
 
   test("simple data stream source") {
@@ -230,4 +273,220 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
       stream => stream.commit(offset)
     }
   }
+
+  Seq("append", "complete").foreach { mode =>
+    test(s"data source stream write - $mode mode") {
+      assume(shouldTestPandasUDFs)
+      val dataSource =
+        createUserDefinedPythonDataSource(dataSourceName, simpleDataStreamWriterScript)
+      spark.dataSource.registerPython(dataSourceName, dataSource)
+      val inputData = MemoryStream[Int]
+      withTempDir { dir =>
+        val path = dir.getAbsolutePath
+        val checkpointDir = new File(path, "checkpoint")
+        checkpointDir.mkdir()
+        val outputDir = new File(path, "output")
+        outputDir.mkdir()
+        val streamDF = if (mode == "append") {
+          inputData.toDF()
+        } else {
+          // Complete mode only supports stateful aggregation
+          inputData.toDF()
+            .groupBy("value").count()
+        }
+        def resultDf: DataFrame = spark.read.format("json")
+          .load(outputDir.getAbsolutePath)
+        val q = streamDF
+          .writeStream
+          .format(dataSourceName)
+          .outputMode(mode)
+          .option("checkpointLocation", checkpointDir.getAbsolutePath)
+          .start(outputDir.getAbsolutePath)
+
+        inputData.addData(1, 2, 3)
+        eventually(timeout(waitTimeout)) {
+          if (mode == "append") {
+            checkAnswer(
+              resultDf,
+              Seq(Row(1), Row(2), Row(3)))
+          } else {
+            checkAnswer(
+              resultDf.select("value", "count"),
+              Seq(Row(1, 1), Row(2, 1), Row(3, 1)))
+          }
+        }
+
+        inputData.addData(1, 4)
+        eventually(timeout(waitTimeout)) {
+          if (mode == "append") {
+            checkAnswer(
+              resultDf,
+              Seq(Row(1), Row(2), Row(3), Row(4), Row(1)))
+          } else {
+            checkAnswer(
+              resultDf.select("value", "count"),
+              Seq(Row(1, 2), Row(2, 1), Row(3, 1), Row(4, 1)))
+          }
+        }
+
+        q.stop()
+        q.awaitTermination()
+        assert(q.exception.isEmpty)
+      }
+    }
+  }
+
+  test("streaming sink write commit and abort") {
+    assume(shouldTestPandasUDFs)
+    // The data source writes the number of rows and partitions into batchId.json in
+    // the output directory in the commit() function. If aborting a microbatch, it writes
+    // batchId.txt into the output directory.
+    val dataSourceScript =
+      s"""
+         |import json
+         |import os
+         |from dataclasses import dataclass
+         |from pyspark import TaskContext
+         |from pyspark.sql.datasource import DataSource, DataSourceStreamWriter, WriterCommitMessage
+         |
+         |@dataclass
+         |class SimpleCommitMessage(WriterCommitMessage):
+         |    partition_id: int
+         |    count: int
+         |
+         |class SimpleDataSourceStreamWriter(DataSourceStreamWriter):
+         |    def __init__(self, options):
+         |        self.options = options
+         |        self.path = self.options.get("path")
+         |        assert self.path is not None
+         |
+         |    def write(self, iterator):
+         |        context = TaskContext.get()
+         |        partition_id = context.partitionId()
+         |        cnt = 0
+         |        for row in iterator:
+         |            if row.value > 50:
+         |                raise Exception("invalid value")
+         |            cnt += 1
+         |        return SimpleCommitMessage(partition_id=partition_id, count=cnt)
+         |
+         |    def commit(self, messages, batchId) -> None:
+         |        status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
+         |
+         |        with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
+         |            file.write(json.dumps(status) + "\\n")
+         |
+         |    def abort(self, messages, batchId) -> None:
+         |        with open(os.path.join(self.path, f"{batchId}.txt"), "w") as 
file:
+         |            file.write(f"failed in batch {batchId}")
+         |
+         |class SimpleDataSource(DataSource):
+         |    def streamWriter(self, schema, overwrite):
+         |        return SimpleDataSourceStreamWriter(self.options)
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val inputData = MemoryStream[Int](numPartitions = 3)
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      checkpointDir.mkdir()
+      val outputDir = new File(path, "output")
+      outputDir.mkdir()
+      val q = inputData.toDF()
+        .writeStream
+        .format(dataSourceName)
+        .outputMode("append")
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .start(outputDir.getAbsolutePath)
+
+      def metadataDf: DataFrame = spark.read.format("json")
+        .load(outputDir.getAbsolutePath)
+
+      // Batch 0-2 should succeed and json commit files are written.
+      inputData.addData(1 to 30)
+      eventually(timeout(waitTimeout)) {
+        checkAnswer(metadataDf, Seq(Row(3, 30)))
+      }
+
+      inputData.addData(31 to 50)
+      eventually(timeout(waitTimeout)) {
+        checkAnswer(metadataDf, Seq(Row(3, 30), Row(3, 20)))
+      }
+
+      // Write and commit an empty batch.
+      inputData.addData(Seq.empty)
+      eventually(timeout(waitTimeout)) {
+        checkAnswer(metadataDf, Seq(Row(3, 30), Row(3, 20), Row(3, 0)))
+      }
+
+      // The sink throws exception when encountering value > 50 in batch 3.
+      // The streamWriter will write error message in 3.txt during abort().
+      inputData.addData(51 to 100)
+      eventually(timeout(waitTimeout)) {
+        checkAnswer(
+          spark.read.text(outputDir.getAbsolutePath + "/3.txt"),
+          Seq(Row("failed in batch 3")))
+      }
+
+      q.stop()
+      assert(q.exception.get.message.contains("invalid value"))
+    }
+  }
+
+  test("python streaming sink: invalid write mode") {
+    assume(shouldTestPandasUDFs)
+    // Append and update modes are accepted for this stateless query. Complete mode
+    // requires a streaming aggregation, and an unrecognized mode fails during planning.
+
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, simpleDataStreamWriterScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      checkpointDir.mkdir()
+      val outputDir = new File(path, "output")
+      outputDir.mkdir()
+
+      def runQuery(mode: String): Unit = {
+        val inputData = MemoryStream[Int]
+        withTempDir { dir =>
+          val path = dir.getAbsolutePath
+          val checkpointDir = new File(path, "checkpoint")
+          checkpointDir.mkdir()
+          val outputDir = new File(path, "output")
+          outputDir.mkdir()
+          val q = inputData.toDF()
+            .writeStream
+            .format(dataSourceName)
+            .outputMode(mode)
+            .option("checkpointLocation", checkpointDir.getAbsolutePath)
+            .start(outputDir.getAbsolutePath)
+          q.stop()
+          q.awaitTermination()
+        }
+      }
+
+      runQuery("append")
+      runQuery("update")
+
+      // Complete mode is not supported for stateless query.
+      checkError(
+        exception = intercept[AnalysisException] {
+          runQuery("complete")
+        },
+        errorClass = "_LEGACY_ERROR_TEMP_3102",
+        parameters = Map(
+          "msg" -> ("Complete output mode not supported when there are no 
streaming aggregations" +
+            " on streaming DataFrames/Datasets")))
+
+      // Query should fail in planning with "invalid" mode.
+      val error2 = intercept[IllegalArgumentException] {
+        runQuery("invalid")
+      }
+      assert(error2.getMessage.contains("invalid"))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

