This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f53dc08230b7 [SPARK-46962][SS][PYTHON] Add interface for python streaming data source API and implement python worker to run python streaming data source
f53dc08230b7 is described below

commit f53dc08230b7a758227f02fd75d2b446721c139f
Author: Chaoqin Li <chaoqin...@databricks.com>
AuthorDate: Tue Mar 12 12:35:18 2024 +0900

    [SPARK-46962][SS][PYTHON] Add interface for python streaming data source API and implement python worker to run python streaming data source
    
    ### What changes were proposed in this pull request?
    This is the first PR to implement support for defining streaming data sources through the python API.
    Implement a python worker to run the python streaming data source and communicate with the JVM through a socket. Create a PythonMicroBatchStream to invoke RPC function calls; this happens in the Spark driver. For each python streaming data source instance there will be one long-lived python worker process. Inside that python process, the python streaming reader receives function calls and parameters from the JVM PythonMicroBatchStream and sends the results back through the socket.
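    An illustrative sketch of what a user-defined streaming data source could look like on top of this API (the class names and offset layout below are hypothetical; the methods mirror the new `DataSourceStreamReader` interface):

    ```python
    # Illustrative only: a hypothetical "counter" source, not part of this commit.
    from pyspark.sql.datasource import DataSource, DataSourceStreamReader, InputPartition


    class CounterStreamReader(DataSourceStreamReader):
        def initialOffset(self):
            # A new query starts here; a restarted query resumes from the checkpointed offset.
            return {"index": 0}

        def latestOffset(self):
            # Report how far the source can currently be read.
            return {"index": 10}

        def partitions(self, start, end):
            # One InputPartition per index between the start and end offsets.
            return [InputPartition(i) for i in range(start["index"], end["index"])]

        def read(self, partition):
            # Runs once per planned partition; yielded tuples become rows.
            yield (partition.value,)

        def commit(self, end):
            # Offsets up to `end` are fully processed; clean up source-side state if needed.
            pass


    class CounterDataSource(DataSource):
        def schema(self):
            return "value INT"

        def streamReader(self, schema):
            return CounterStreamReader()
    ```
    Note that reading the planned partitions end to end still depends on the partition reader factory, which is left as a TODO (SPARK-47107) in this PR.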
    
    ### Why are the changes needed?
    In preparation for supporting the development of streaming data sources in Python.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit tests. Verify that PythonMicroBatchStream plans offsets and partitions correctly by invoking function calls through the socket, and that errors are handled correctly.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45023 from chaoqin-li1123/python_table.
    
    Lead-authored-by: Chaoqin Li <chaoqin...@databricks.com>
    Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com>
    Co-authored-by: chaoqin-li1123 <55518381+chaoqin-li1...@users.noreply.github.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../src/main/resources/error/error-classes.json    |   6 +
 docs/sql-error-conditions.md                       |   6 +
 python/pyspark/errors/error_classes.py             |   5 +
 python/pyspark/sql/datasource.py                   | 153 ++++++++++++++
 .../streaming/python_streaming_source_runner.py    | 167 +++++++++++++++
 .../spark/sql/errors/QueryExecutionErrors.scala    |   9 +
 .../v2/python/PythonMicroBatchStream.scala         |  68 ++++++
 .../datasources/v2/python/PythonScan.scala         |  32 ++-
 .../v2/python/UserDefinedPythonDataSource.scala    |   2 +-
 .../python/PythonStreamingSourceRunner.scala       | 202 ++++++++++++++++++
 .../execution/python/PythonDataSourceSuite.scala   |  20 +-
 .../python/PythonStreamingDataSourceSuite.scala    | 233 +++++++++++++++++++++
 12 files changed, 883 insertions(+), 20 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index afe81b8e9bea..3d130fdce301 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3197,6 +3197,12 @@
     ],
     "sqlState" : "38000"
   },
+  "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR" : {
+    "message" : [
+      "Failed when Python streaming data source perform <action>: <msg>"
+    ],
+    "sqlState" : "38000"
+  },
   "RECURSIVE_PROTOBUF_SCHEMA" : {
     "message" : [
       "Found recursive reference in Protobuf schema, which can not be 
processed by Spark by default: <fieldDescriptor>. try setting the option 
`recursive.fields.max.depth` 0 to 10. Going beyond 10 levels of recursion is 
not allowed."
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 0695ed28b7fc..2cddb6a94c14 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -1931,6 +1931,12 @@ Protobuf type not yet supported: `<protobufType>`.
 
 Failed to `<action>` Python data source `<type>`: `<msg>`
 
+### PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR
+
+[SQLSTATE: 38000](sql-error-conditions-sqlstates.html#class-38-external-routine-exception)
+
+Failed when Python streaming data source performs `<action>`: `<msg>`
+
 ### RECURSIVE_PROTOBUF_SCHEMA
 
[SQLSTATE: 42K0G](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index c9a7cfbf356e..1e21ad3543e9 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -812,6 +812,11 @@ ERROR_CLASSES_JSON = '''
       "Randomness of hash of string should be disabled via PYTHONHASHSEED."
     ]
   },
+  "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR": {
+    "message": [
+      "Failed when running Python streaming data source: <msg>"
+    ]
+  },
   "PYTHON_VERSION_MISMATCH": {
     "message": [
       "Python in worker has different version: <worker_version> than that in 
driver: <driver_version>, PySpark cannot run with different minor versions.",
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index 1e50f8270243..b98729f04332 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
 __all__ = [
     "DataSource",
     "DataSourceReader",
+    "DataSourceStreamReader",
     "DataSourceWriter",
     "DataSourceRegistration",
     "InputPartition",
@@ -159,6 +160,27 @@ class DataSource(ABC):
             message_parameters={"feature": "writer"},
         )
 
+    def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
+        """
+        Returns a ``DataSourceStreamReader`` instance for reading streaming data.
+
+        The implementation is required for readable streaming data sources.
+
+        Parameters
+        ----------
+        schema : StructType
+            The schema of the data to be read.
+
+        Returns
+        -------
+        reader : DataSourceStreamReader
+            A reader instance for this streaming data source.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "streamReader"},
+        )
+
 
 class InputPartition:
     """
@@ -298,6 +320,137 @@ class DataSourceReader(ABC):
         ...
 
 
+class DataSourceStreamReader(ABC):
+    """
+    A base class for streaming data source readers. Data source stream readers are responsible
+    for outputting data from a streaming data source.
+
+    .. versionadded:: 4.0.0
+    """
+
+    def initialOffset(self) -> dict:
+        """
+        Return the initial offset of the streaming data source.
+        A new streaming query starts reading data from the initial offset.
+        If Spark is restarting an existing query, it will restart from the checkpointed offset
+        rather than the initial one.
+
+        Returns
+        -------
+        dict
+            A dict or recursive dict whose keys and values are primitive types, which include
+            Integer, String and Boolean.
+
+        Examples
+        --------
+        >>> def initialOffset(self):
+        ...     return {"parititon-1": {"index": 3, "closed": True}, 
"partition-2": {"index": 5}}
+        """
+
+        ...
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "initialOffset"},
+        )
+
+    def latestOffset(self) -> dict:
+        """
+        Returns the most recent offset available.
+
+        Returns
+        -------
+        dict
+            A dict or recursive dict whose keys and values are primitive types, which include
+            Integer, String and Boolean.
+
+        Examples
+        --------
+        >>> def latestOffset(self):
+        ...     return {"parititon-1": {"index": 3, "closed": True}, 
"partition-2": {"index": 5}}
+        """
+        ...
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "latestOffset"},
+        )
+
+    def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]:
+        """
+        Returns a list of InputPartition given the start and end offsets. Each InputPartition
+        represents a data split that can be processed by one Spark task.
+
+        Parameters
+        ----------
+        start : dict
+            The start offset of the microbatch to plan partitioning.
+        end : dict
+            The end offset of the microbatch to plan partitioning.
+
+        Returns
+        -------
+        Sequence[InputPartition]
+            A sequence of partitions for this data source. Each partition value
+            must be an instance of `InputPartition` or a subclass of it.
+        """
+        ...
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "partitions"},
+        )
+
+    @abstractmethod
+    def read(self, partition: InputPartition) -> Iterator[Union[Tuple, Row]]:
+        """
+        Generates data for a given partition and returns an iterator of tuples or rows.
+
+        This method is invoked once per partition to read the data. Implementing
+        this method is required for stream readers. You can initialize any
+        non-serializable resources required for reading data from the data source
+        within this method.
+
+        Notes
+        -----
+        This method is static and stateless. You shouldn't access mutable class members
+        or keep in-memory state between different invocations of read().
+
+        Parameters
+        ----------
+        partition : InputPartition
+            The partition to read. It must be one of the partition values returned by
+            ``partitions()``.
+
+        Returns
+        -------
+        Iterator[Tuple] or Iterator[Row]
+            An iterator of tuples or rows. Each tuple or row will be converted to a row
+            in the final DataFrame.
+        """
+        ...
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "read"},
+        )
+
+    def commit(self, end: dict) -> None:
+        """
+        Informs the source that Spark has completed processing all data for offsets less than or
+        equal to `end` and will only request offsets greater than `end` in the future.
+
+        Parameters
+        ----------
+        end : dict
+            The latest offset that the streaming query has processed for this source.
+        """
+        ...
+
+    def stop(self) -> None:
+        """
+        Stop this source and free any resources it has allocated.
+        Invoked when the streaming query is terminated.
+        """
+        ...
+
+
 class DataSourceWriter(ABC):
     """
     A base class for data source writers. Data source writers are responsible for saving
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
new file mode 100644
index 000000000000..8dbac431a8ba
--- /dev/null
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -0,0 +1,167 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import json
+from typing import IO
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.errors import IllegalArgumentException, PySparkAssertionError, PySparkRuntimeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_int,
+    write_int,
+    write_with_length,
+    SpecialLengths,
+)
+from pyspark.sql.datasource import DataSource, DataSourceStreamReader
+from pyspark.sql.types import (
+    _parse_datatype_json_string,
+    StructType,
+)
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_memory_limits,
+    setup_spark_files,
+    utf8_deserializer,
+)
+
+INITIAL_OFFSET_FUNC_ID = 884
+LATEST_OFFSET_FUNC_ID = 885
+PARTITIONS_FUNC_ID = 886
+COMMIT_FUNC_ID = 887
+
+
+def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
+    offset = reader.initialOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def latest_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
+    offset = reader.latestOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def partitions_func(reader: DataSourceStreamReader, infile: IO, outfile: IO) -> None:
+    start_offset = json.loads(utf8_deserializer.loads(infile))
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    partitions = reader.partitions(start_offset, end_offset)
+    # Return the serialized partition values.
+    write_int(len(partitions), outfile)
+    for partition in partitions:
+        pickleSer._write_with_length(partition, outfile)
+
+
+def commit_func(reader: DataSourceStreamReader, infile: IO, outfile: IO) -> None:
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    reader.commit(end_offset)
+    write_int(0, outfile)
+
+
+def main(infile: IO, outfile: IO) -> None:
+    try:
+        check_python_version(infile)
+        setup_spark_files(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
+        setup_memory_limits(memory_limit_mb)
+
+        _accumulatorRegistry.clear()
+
+        # Receive the data source instance.
+        data_source = read_command(pickleSer, infile)
+
+        if not isinstance(data_source, DataSource):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "a Python data source instance of type 
'DataSource'",
+                    "actual": f"'{type(data_source).__name__}'",
+                },
+            )
+
+        # Receive the data source output schema.
+        schema_json = utf8_deserializer.loads(infile)
+        schema = _parse_datatype_json_string(schema_json)
+        if not isinstance(schema, StructType):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "an output schema of type 'StructType'",
+                    "actual": f"'{type(schema).__name__}'",
+                },
+            )
+
+        # Instantiate data source reader.
+        try:
+            reader = data_source.streamReader(schema=schema)
+            # Initialization succeeded.
+            write_int(0, outfile)
+            outfile.flush()
+
+            # Handle method calls from the socket.
+            while True:
+                func_id = read_int(infile)
+                if func_id == INITIAL_OFFSET_FUNC_ID:
+                    initial_offset_func(reader, outfile)
+                elif func_id == LATEST_OFFSET_FUNC_ID:
+                    latest_offset_func(reader, outfile)
+                elif func_id == PARTITIONS_FUNC_ID:
+                    partitions_func(reader, infile, outfile)
+                elif func_id == COMMIT_FUNC_ID:
+                    commit_func(reader, infile, outfile)
+                else:
+                    raise IllegalArgumentException(
+                        error_class="UNSUPPORTED_OPERATION",
+                        message_parameters={
+                            "operation": "Function call id not recognized by 
stream reader"
+                        },
+                    )
+                outfile.flush()
+        except Exception as e:
+            error_msg = "data source {} threw exception: {}".format(data_source.name, e)
+            raise PySparkRuntimeError(
+                error_class="PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
+                message_parameters={"msg": error_msg},
+            )
+        finally:
+            reader.stop()
+    except BaseException as e:
+        handle_worker_exception(e, outfile)
+        sys.exit(-1)
+    send_accumulator_updates(outfile)
+
+    # check end of stream
+    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+        write_int(SpecialLengths.END_OF_STREAM, outfile)
+    else:
+        # write a different value to tell JVM to not reuse this worker
+        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    # Read information about how to connect back to the JVM from the environment.
+    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    main(sock_file, sock_file)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 8280826fafcb..32482015b3ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1758,6 +1758,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
         "outputMode" -> outputMode.toString()))
   }
 
+  def pythonStreamingDataSourceRuntimeError(
+      action: String,
+      message: String): SparkException = {
+    new SparkException(
+      errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
+      messageParameters = Map("action" -> action, "msg" -> message),
+      cause = null)
+  }
+
   def invalidCatalogNameError(name: String): Throwable = {
     new SparkException(
       errorClass = "_LEGACY_ERROR_TEMP_2212",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
new file mode 100644
index 000000000000..4e77f33c24f0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
+import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
+import org.apache.spark.sql.execution.python.PythonStreamingSourceRunner
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+case class PythonStreamingSourceOffset(json: String) extends Offset
+
+case class PythonStreamingSourcePartition(partition: Array[Byte]) extends InputPartition
+
+class PythonMicroBatchStream(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap
+  ) extends MicroBatchStream with Logging {
+  private def createDataSourceFunc =
+    ds.source.createPythonFunction(
+      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)).dataSource)
+
+  private val runner: PythonStreamingSourceRunner =
+    new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
+  runner.init()
+
+  override def initialOffset(): Offset = PythonStreamingSourceOffset(runner.initialOffset())
+
+  override def latestOffset(): Offset = PythonStreamingSourceOffset(runner.latestOffset())
+
+  override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = {
+    runner.partitions(start.asInstanceOf[PythonStreamingSourceOffset].json,
+      end.asInstanceOf[PythonStreamingSourceOffset].json).map(PythonStreamingSourcePartition(_))
+  }
+
+  override def createReaderFactory(): PartitionReaderFactory = {
+    // TODO(SPARK-47107): fill in the implementation.
+    null
+  }
+
+  override def commit(end: Offset): Unit = {
+    runner.commit(end.asInstanceOf[PythonStreamingSourceOffset].json)
+  }
+
+  override def stop(): Unit = {
+    runner.stop()
+  }
+
+  override def deserializeOffset(json: String): Offset = PythonStreamingSourceOffset(json)
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index 75cbe38b1397..bcddf66fc161 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.python
 import org.apache.spark.JobArtifactSet
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.connector.read.streaming.MicroBatchStream
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -27,9 +28,27 @@ class PythonScan(
      ds: PythonDataSourceV2,
      shortName: String,
      outputSchema: StructType,
-     options: CaseInsensitiveStringMap) extends Batch with Scan {
+     options: CaseInsensitiveStringMap) extends Scan {
 
-  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  override def toBatch: Batch = new PythonBatch(ds, shortName, outputSchema, options)
+
+  override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream =
+    new PythonMicroBatchStream(ds, shortName, outputSchema, options)
+
+  override def description: String = "(Python)"
+
+  override def readSchema(): StructType = outputSchema
+
+  override def supportedCustomMetrics(): Array[CustomMetric] =
+    ds.source.createPythonMetrics()
+}
+
+class PythonBatch(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap) extends Batch {
+  private val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
   private lazy val infoInPython: PythonDataSourceReadInfo = {
     ds.source.createReadInfoInPython(
@@ -45,13 +64,4 @@ class PythonScan(
     new PythonPartitionReaderFactory(
       ds.source, readerFunc, outputSchema, jobArtifactUUID)
   }
-
-  override def toBatch: Batch = this
-
-  override def description: String = "(Python)"
-
-  override def readSchema(): StructType = outputSchema
-
-  override def supportedCustomMetrics(): Array[CustomMetric] =
-    ds.source.createPythonMetrics()
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 0e5f359ee76f..9b8219c4dc2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -150,7 +150,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
     taskMetrics.map { case (k, v) => new PythonCustomTaskMetric(k, v)}.toArray
   }
 
-  private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = {
+  def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = {
     SimplePythonFunction(
       command = pickledFunc.toImmutableArraySeq,
       envVars = dataSourceCls.envVars,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
new file mode 100644
index 000000000000..4862f8220f06
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.types.StructType
+
+object PythonStreamingSourceRunner {
+  // When the python process for python_streaming_source_runner receives one of the
+  // integers below, it will invoke the corresponding function of the StreamReader instance.
+  val INITIAL_OFFSET_FUNC_ID = 884
+  val LATEST_OFFSET_FUNC_ID = 885
+  val PARTITIONS_FUNC_ID = 886
+  val COMMIT_FUNC_ID = 887
+}
+
+/**
+ * This class is a proxy to invoke methods of a Python DataSourceStreamReader from the JVM.
+ * A runner spawns a python worker process. In its main function, the worker sets up
+ * communication between the JVM and the python process through a socket and creates a
+ * DataSourceStreamReader instance. In an infinite loop, the python worker process polls
+ * information (function name and parameters) from the socket, invokes the corresponding
+ * method of the StreamReader, and sends the return value back to the JVM.
+ */
+class PythonStreamingSourceRunner(
+    func: PythonFunction,
+    outputSchema: StructType) extends Logging  {
+  val workerModule = "pyspark.sql.streaming.python_streaming_source_runner"
+
+  private val conf = SparkEnv.get.conf
+  private val bufferSize: Int = conf.get(BUFFER_SIZE)
+  private val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+
+  private val envVars: java.util.Map[String, String] = func.envVars
+  private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[PythonWorker] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
+  private val pythonVer: String = func.pythonVer
+
+  private var dataOut: DataOutputStream = null
+  private var dataIn: DataInputStream = null
+
+  import PythonStreamingSourceRunner._
+
+  /**
+   * Initializes the Python worker for running the streaming source.
+   */
+  def init(): Unit = {
+    logInfo(s"Initializing Python runner pythonExec: $pythonExec")
+    val env = SparkEnv.get
+
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    val workerFactory =
+      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap)
+    val (worker: PythonWorker, _) = workerFactory.createSimpleWorker(blockingMode = true)
+    pythonWorker = Some(worker)
+    pythonWorkerFactory = Some(workerFactory)
+
+    val stream = new BufferedOutputStream(
+      pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+    dataOut = new DataOutputStream(stream)
+
+    PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+
+    val pythonIncludes = func.pythonIncludes.asScala.toSet
+    PythonWorkerUtils.writeSparkFiles(Some("streaming_job"), pythonIncludes, dataOut)
+
+    // Send the user function to python process
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    dataOut.flush()
+
+    dataIn = new DataInputStream(
+      new BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, bufferSize))
+
+    val initStatus = dataIn.readInt()
+    if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(
+        action = "plan", tpe = "initialize source", msg = msg)
+    }
+  }
+
+  /**
+   * Invokes latestOffset() function of the stream reader and receives the return value.
+   */
+  def latestOffset(): String = {
+    dataOut.writeInt(LATEST_OFFSET_FUNC_ID)
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "latestOffset", msg)
+    }
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
+  /**
+   * Invokes initialOffset() function of the stream reader and receives the return value.
+   */
+  def initialOffset(): String = {
+    dataOut.writeInt(INITIAL_OFFSET_FUNC_ID)
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "initialOffset", msg)
+    }
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
+  /**
+   * Invokes partitions(start, end) function of the stream reader and receives the return value.
+   */
+  def partitions(start: String, end: String): Array[Array[Byte]] = {
+    dataOut.writeInt(PARTITIONS_FUNC_ID)
+    PythonWorkerUtils.writeUTF(start, dataOut)
+    PythonWorkerUtils.writeUTF(end, dataOut)
+    dataOut.flush()
+    // Receive the list of partitions, if any.
+    val pickledPartitions = ArrayBuffer.empty[Array[Byte]]
+    val numPartitions = dataIn.readInt()
+    if (numPartitions == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "planPartitions", msg)
+    }
+    for (_ <- 0 until numPartitions) {
+      val pickledPartition: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
+      pickledPartitions.append(pickledPartition)
+    }
+    pickledPartitions.toArray
+  }
+
+  /**
+   * Invokes commit(end) function of the stream reader and receives the return value.
+   */
+  def commit(end: String): Unit = {
+    dataOut.writeInt(COMMIT_FUNC_ID)
+    PythonWorkerUtils.writeUTF(end, dataOut)
+    dataOut.flush()
+    val status = dataIn.readInt()
+    if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "commitSource", msg)
+    }
+  }
+
+  /**
+   * Stop the python worker process and invoke stop() on the stream reader.
+   */
+  def stop(): Unit = {
+    logInfo(s"Stopping streaming runner for module: $workerModule.")
+    try {
+      pythonWorkerFactory.foreach { factory =>
+        pythonWorker.foreach { worker =>
+          factory.stopWorker(worker)
+          factory.stop()
+        }
+      }
+    } catch {
+      case e: Exception =>
+        logError("Exception when trying to kill worker", e)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index c2f5864b7345..b207afeae106 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -28,13 +28,9 @@ import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
-class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
-  import IntegratedUDFTestUtils._
-
-  setupTestData()
+abstract class PythonDataSourceSuiteBase extends QueryTest with SharedSparkSession {
 
-  private def dataSourceName = "SimpleDataSource"
-  private val simpleDataSourceReaderScript: String =
+  protected val simpleDataSourceReaderScript: String =
     """
       |from pyspark.sql.datasource import DataSourceReader, InputPartition
       |class SimpleDataSourceReader(DataSourceReader):
@@ -45,8 +41,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
       |        yield (1, partition.value)
       |        yield (2, partition.value)
       |""".stripMargin
-  private val staticSourceName = "custom_source"
-  private var tempDir: File = _
+  protected val staticSourceName = "custom_source"
+  protected var tempDir: File = _
 
   override def beforeAll(): Unit = {
     // Create a Python Data Source package before starting up the Spark Session
@@ -90,6 +86,14 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  setupTestData()
+
+  protected def dataSourceName = "SimpleDataSource"
+}
+
+class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
+  import IntegratedUDFTestUtils._
+
   test("SPARK-45917: automatic registration of Python Data Source") {
     assume(shouldTestPandasUDFs)
     val df = spark.read.format(staticSourceName).load()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
new file mode 100644
index 000000000000..f022e353edd7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, shouldTestPandasUDFs}
+import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonStreamingSourceOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
+
+  protected def simpleDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
+      |
+      |class SimpleDataStreamReader(DataSourceStreamReader):
+      |    def initialOffset(self):
+      |        return {"offset": {"partition-1": 0}}
+      |    def latestOffset(self):
+      |        return {"offset": {"partition-1": 2}}
+      |    def partitions(self, start: dict, end: dict):
+      |        start_index = start["offset"]["partition-1"]
+      |        end_index = end["offset"]["partition-1"]
+      |        return [InputPartition(i) for i in range(start_index, end_index)]
+      |    def commit(self, end: dict):
+      |        1 + 2
+      |    def read(self, partition):
+      |        yield (0, partition.value)
+      |        yield (1, partition.value)
+      |        yield (2, partition.value)
+      |""".stripMargin
+
+  protected def errorDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
+      |
+      |class ErrorDataStreamReader(DataSourceStreamReader):
+      |    def initialOffset(self):
+      |        raise Exception("error reading initial offset")
+      |    def latestOffset(self):
+      |        raise Exception("error reading latest offset")
+      |    def partitions(self, start: dict, end: dict):
+      |        raise Exception("error planning partitions")
+      |    def commit(self, end: dict):
+      |        raise Exception("error committing offset")
+      |    def read(self, partition):
+      |        yield (0, partition.value)
+      |        yield (1, partition.value)
+      |        yield (2, partition.value)
+      |""".stripMargin
+
+  private val errorDataSourceName = "ErrorDataSource"
+
+  test("simple data stream source") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("SimpleDataSource")
+    val stream = new PythonMicroBatchStream(
+      pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+
+    val initialOffset = stream.initialOffset()
+    assert(initialOffset.json == "{\"offset\": {\"partition-1\": 0}}")
+    for (_ <- 1 to 50) {
+      val offset = stream.latestOffset()
+      assert(offset.json == "{\"offset\": {\"partition-1\": 2}}")
+      assert(stream.planInputPartitions(initialOffset, offset).size == 2)
+      stream.commit(offset)
+    }
+    stream.stop()
+  }
+
+  test("Error creating stream reader") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |class $dataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        raise Exception("error creating stream reader")
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(
+      name = dataSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("SimpleDataSource")
+    val inputSchema = StructType.fromDDL("input BINARY")
+    val err = intercept[AnalysisException] {
+      new PythonMicroBatchStream(
+        pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+    }
+    assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR")
+    assert(err.getMessage.contains("error creating stream reader"))
+  }
+
+  test("Method not implemented in stream reader") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |from pyspark.sql.datasource import DataSourceStreamReader
+         |class ErrorDataStreamReader(DataSourceStreamReader):
+         |    def read(self, partition):
+         |        yield (0, partition.value)
+         |
+         |class $errorDataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        return ErrorDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource = createUserDefinedPythonDataSource(errorDataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(errorDataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("ErrorDataSource")
+
+    def testMicroBatchStreamError(action: String, msg: String)
+                                 (func: PythonMicroBatchStream => Unit): Unit = {
+      val stream = new PythonMicroBatchStream(
+        pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+      val err = intercept[SparkException] {
+        func(stream)
+      }
+      checkErrorMatchPVals(err,
+        errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
+        parameters = Map(
+          "action" -> action,
+          "msg" -> "(.|\\n)*"
+        ))
+      assert(err.getMessage.contains(msg))
+      assert(err.getMessage.contains("ErrorDataSource"))
+      stream.stop()
+    }
+
+    testMicroBatchStreamError(
+      "initialOffset", "[NOT_IMPLEMENTED] initialOffset is not implemented") {
+      stream => stream.initialOffset()
+    }
+
+    testMicroBatchStreamError(
+      "latestOffset", "[NOT_IMPLEMENTED] latestOffset is not implemented") {
+      stream => stream.latestOffset()
+    }
+
+    val offset = PythonStreamingSourceOffset("{\"offset\": \"2\"}")
+    testMicroBatchStreamError(
+      "planPartitions", "[NOT_IMPLEMENTED] partitions is not implemented") {
+      stream => stream.planInputPartitions(offset, offset)
+    }
+  }
+
+  test("Error in stream reader") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$errorDataStreamReaderScript
+         |
+         |class $errorDataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        return ErrorDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource = createUserDefinedPythonDataSource(errorDataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(errorDataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("ErrorDataSource")
+    val offset = PythonStreamingSourceOffset("{\"offset\": \"2\"}")
+
+    def testMicroBatchStreamError(action: String, msg: String)
+                                 (func: PythonMicroBatchStream => Unit): Unit = {
+      val stream = new PythonMicroBatchStream(
+        pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+      val err = intercept[SparkException] {
+        func(stream)
+      }
+      checkErrorMatchPVals(err,
+        errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
+        parameters = Map(
+          "action" -> action,
+          "msg" -> "(.|\\n)*"
+        ))
+      assert(err.getMessage.contains(msg))
+      assert(err.getMessage.contains("ErrorDataSource"))
+      stream.stop()
+    }
+
+    testMicroBatchStreamError("initialOffset", "error reading initial offset") 
{
+      stream => stream.initialOffset()
+    }
+
+    testMicroBatchStreamError("latestOffset", "error reading latest offset") {
+      stream => stream.latestOffset()
+    }
+
+    testMicroBatchStreamError("planPartitions", "error planning partitions") {
+      stream => stream.planInputPartitions(offset, offset)
+    }
+
+    testMicroBatchStreamError("commitSource", "error committing offset") {
+      stream => stream.commit(offset)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

