This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e3e7135af4df [SPARK-47107][SS][PYTHON] Implement partition reader for python streaming data source
e3e7135af4df is described below

commit e3e7135af4df3427f4c61cccfe189f702844e1f5
Author: Chaoqin Li <chaoqin...@databricks.com>
AuthorDate: Thu Mar 28 06:33:49 2024 +0900

    [SPARK-47107][SS][PYTHON] Implement partition reader for python streaming data source
    
    ### What changes were proposed in this pull request?
    Piggyback on the PythonPartitionReaderFactory to implement reading a data partition for the Python streaming data source. Add a test case to verify that a Python streaming data source can read and process data end to end.
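    For illustration, here is a minimal sketch of a Python streaming source whose partitions are read through this new path. It is adapted from the test scripts in this patch; the class names and the range-based offsets are illustrative, not prescribed by the API.

    ```python
    from pyspark.sql.datasource import DataSource, DataSourceStreamReader, InputPartition


    class RangePartition(InputPartition):
        def __init__(self, start, end):
            self.start = start
            self.end = end


    class SimpleStreamReader(DataSourceStreamReader):
        current = 0

        def initialOffset(self):
            return {"offset": 0}

        def latestOffset(self):
            # Advance the offset by two rows per microbatch.
            self.current += 2
            return {"offset": self.current}

        def partitions(self, start: dict, end: dict):
            # Partitions are planned per microbatch; each InputPartition is
            # pickled and handed to the partition reader on the executors.
            return [RangePartition(start["offset"], end["offset"])]

        def read(self, partition: RangePartition):
            for i in range(partition.start, partition.end):
                yield (i,)

        def commit(self, end: dict):
            pass


    class SimpleStreamDataSource(DataSource):
        def schema(self) -> str:
            return "id INT"

        def streamReader(self, schema):
            return SimpleStreamReader()
    ```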
    
    ### Why are the changes needed?
    This is part of the effort to support developing streaming data sources through the Python interface.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Add an integration test to verify that data is read and metrics are emitted correctly.
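    As a rough sketch of the end-to-end path the integration test exercises (assuming the SimpleStreamDataSource class from the sketch above; the registration call, trigger interval, and sleep here are illustrative rather than taken from the test):

    ```python
    import time

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Register the example source; by default its short name is the class name.
    spark.dataSource.register(SimpleStreamDataSource)

    df = spark.readStream.format("SimpleStreamDataSource").load()

    query = (
        df.writeStream
        .format("console")
        .trigger(processingTime="0 seconds")
        .start()
    )

    time.sleep(5)  # let a few microbatches run
    if query.lastProgress is not None:
        # Each microbatch should report the rows planned for its partitions.
        print(query.lastProgress["numInputRows"])
    query.stop()
    ```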
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45485 from chaoqin-li1123/python_stream_read.
    
    Authored-by: Chaoqin Li <chaoqin...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../streaming/python_streaming_source_runner.py    |   2 +-
 python/pyspark/sql/worker/plan_data_source_read.py |  75 +++++----
 .../v2/python/PythonMicroBatchStream.scala         |  16 +-
 .../datasources/v2/python/PythonScan.scala         |   3 +-
 .../datasources/v2/python/PythonTable.scala        |   4 +-
 .../v2/python/UserDefinedPythonDataSource.scala    |  16 +-
 .../spark/sql/streaming/DataStreamReader.scala     |   5 +
 .../python/PythonStreamingDataSourceSuite.scala    | 182 ++++++++++++++++++---
 8 files changed, 238 insertions(+), 65 deletions(-)

diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 8dbac431a8ba..512191866a16 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -141,7 +141,7 @@ def main(infile: IO, outfile: IO) -> None:
             error_msg = "data source {} throw exception: {}".format(data_source.name, e)
             raise PySparkRuntimeError(
                 error_class="PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
-                message_parameters={"error": error_msg},
+                message_parameters={"msg": error_msg},
             )
         finally:
             reader.stop()
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index 8f1fc1e59a61..3e5105996ed4 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -25,6 +25,7 @@ from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import (
+    read_bool,
     read_int,
     write_int,
     SpecialLengths,
@@ -127,33 +128,14 @@ def main(infile: IO, outfile: IO) -> None:
             f"'{max_arrow_batch_size}'"
         )
 
-        # Instantiate data source reader.
-        reader = data_source.reader(schema=schema)
+        is_streaming = read_bool(infile)
 
-        # Get the partitions if any.
-        try:
-            partitions = reader.partitions()
-            if not isinstance(partitions, list):
-                raise PySparkRuntimeError(
-                    error_class="DATA_SOURCE_TYPE_MISMATCH",
-                    message_parameters={
-                        "expected": "'partitions' to return a list",
-                        "actual": f"'{type(partitions).__name__}'",
-                    },
-                )
-            if not all(isinstance(p, InputPartition) for p in partitions):
-                partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
-                raise PySparkRuntimeError(
-                    error_class="DATA_SOURCE_TYPE_MISMATCH",
-                    message_parameters={
-                        "expected": "all elements in 'partitions' to be of type 'InputPartition'",
-                        "actual": partition_types,
-                    },
-                )
-            if len(partitions) == 0:
-                partitions = [None]  # type: ignore
-        except NotImplementedError:
-            partitions = [None]  # type: ignore
+        # Instantiate data source reader.
+        reader = (
+            data_source.streamReader(schema=schema)
+            if is_streaming
+            else data_source.reader(schema=schema)
+        )
 
         # Wrap the data source read logic in an mapInArrow UDF.
         import pyarrow as pa
@@ -195,7 +177,7 @@ def main(infile: IO, outfile: IO) -> None:
                 f"but found '{type(partition).__name__}'."
             )
 
-            output_iter = reader.read(partition)  # type: ignore[arg-type]
+            output_iter = reader.read(partition)  # type: ignore[attr-defined]
 
             # Validate the output iterator.
             if not isinstance(output_iter, Iterator):
@@ -264,11 +246,40 @@ def main(infile: IO, outfile: IO) -> None:
         command = (data_source_read_func, return_type)
         pickleSer._write_with_length(command, outfile)
 
-        # Return the serialized partition values.
-        write_int(len(partitions), outfile)
-        for partition in partitions:
-            pickleSer._write_with_length(partition, outfile)
-
+        if not is_streaming:
+            # The partitioning of python batch source read is determined before query execution.
+            try:
+                partitions = reader.partitions()  # type: ignore[attr-defined]
+                if not isinstance(partitions, list):
+                    raise PySparkRuntimeError(
+                        error_class="DATA_SOURCE_TYPE_MISMATCH",
+                        message_parameters={
+                            "expected": "'partitions' to return a list",
+                            "actual": f"'{type(partitions).__name__}'",
+                        },
+                    )
+                if not all(isinstance(p, InputPartition) for p in partitions):
+                    partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
+                    raise PySparkRuntimeError(
+                        error_class="DATA_SOURCE_TYPE_MISMATCH",
+                        message_parameters={
+                            "expected": "elements in 'partitions' to be of type 'InputPartition'",
+                            "actual": partition_types,
+                        },
+                    )
+                if len(partitions) == 0:
+                    partitions = [None]
+            except NotImplementedError:
+                partitions = [None]
+
+            # Return the serialized partition values.
+            write_int(len(partitions), outfile)
+            for partition in partitions:
+                pickleSer._write_with_length(partition, outfile)
+        else:
+            # Send an empty list of partition for stream reader because partitions are planned
+            # in each microbatch during query execution.
+            write_int(0, outfile)
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index 4e77f33c24f0..71e6c29bc299 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -25,8 +25,6 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 case class PythonStreamingSourceOffset(json: String) extends Offset
 
-case class PythonStreamingSourcePartition(partition: Array[Byte]) extends InputPartition
-
 class PythonMicroBatchStream(
     ds: PythonDataSourceV2,
     shortName: String,
@@ -47,12 +45,20 @@ class PythonMicroBatchStream(
 
   override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = {
     runner.partitions(start.asInstanceOf[PythonStreamingSourceOffset].json,
-      end.asInstanceOf[PythonStreamingSourceOffset].json).map(PythonStreamingSourcePartition(_))
+      end.asInstanceOf[PythonStreamingSourceOffset].json)
+      .zipWithIndex.map(p => PythonInputPartition(p._2, p._1))
+  }
+
+  private lazy val readInfo: PythonDataSourceReadInfo = {
+    ds.source.createReadInfoInPython(
+      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
+      outputSchema,
+      isStreaming = true)
   }
 
   override def createReaderFactory(): PartitionReaderFactory = {
-    // TODO(SPARK-47107): fill in the implementation.
-    null
+    new PythonPartitionReaderFactory(
+      ds.source, readInfo.func, outputSchema, None)
   }
 
   override def commit(end: Offset): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index bcddf66fc161..8fefc8b144a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -53,7 +53,8 @@ class PythonBatch(
   private lazy val infoInPython: PythonDataSourceReadInfo = {
     ds.source.createReadInfoInPython(
       ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
-      outputSchema)
+      outputSchema,
+      isStreaming = false)
   }
 
   override def planInputPartitions(): Array[InputPartition] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
index 0476650a60bf..f633e601f424 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
-import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE}
+import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, MICRO_BATCH_READ, STREAMING_WRITE, TRUNCATE}
 import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
 import org.apache.spark.sql.types.StructType
@@ -32,7 +32,7 @@ class PythonTable(
   override def name(): String = shortName
 
   override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(
-    BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE)
+    BATCH_READ, BATCH_WRITE, MICRO_BATCH_READ, STREAMING_WRITE, TRUNCATE)
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
     new PythonScanBuilder(ds, shortName, outputSchema, options)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 0586d1fd4bc1..241d8087fc3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -67,11 +67,13 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
    */
   def createReadInfoInPython(
       pythonResult: PythonDataSourceCreationResult,
-      outputSchema: StructType): PythonDataSourceReadInfo = {
+      outputSchema: StructType,
+      isStreaming: Boolean): PythonDataSourceReadInfo = {
     new UserDefinedPythonDataSourceReadRunner(
       createPythonFunction(pythonResult.dataSource),
       UserDefinedPythonDataSource.readInputSchema,
-      outputSchema).runInPython()
+      outputSchema,
+      isStreaming).runInPython()
   }
 
   /**
@@ -312,7 +314,8 @@ case class PythonDataSourceReadInfo(
 private class UserDefinedPythonDataSourceReadRunner(
     func: PythonFunction,
     inputSchema: StructType,
-    outputSchema: StructType) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
+    outputSchema: StructType,
+    isStreaming: Boolean) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
 
   // See the logic in `pyspark.sql.worker.plan_data_source_read.py`.
   override val workerModule = "pyspark.sql.worker.plan_data_source_read"
@@ -329,6 +332,8 @@ private class UserDefinedPythonDataSourceReadRunner(
 
     // Send configurations
     dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+
+    dataOut.writeBoolean(isStreaming)
   }
 
   override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReadInfo = {
@@ -346,6 +351,11 @@ private class UserDefinedPythonDataSourceReadRunner(
     // Receive the list of partitions, if any.
     val pickledPartitions = ArrayBuffer.empty[Array[Byte]]
     val numPartitions = dataIn.readInt()
+    if (numPartitions == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(
+        action = "plan", tpe = "read", msg = msg)
+    }
     for (_ <- 0 until numPartitions) {
       val pickledPartition: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
       pickledPartitions.append(pickledPartition)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 6e24e14fb1eb..24d769fc8fc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.json.JsonUtils.checkJsonSchema
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2}
+import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
 import org.apache.spark.sql.execution.datasources.xml.XmlUtils.checkXmlSchema
 import org.apache.spark.sql.execution.streaming.StreamingRelation
 import org.apache.spark.sql.sources.StreamSourceProvider
@@ -178,6 +179,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
         val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++
             optionsWithPath.originalMap
         val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
+        provider match {
+          case p: PythonDataSourceV2 => p.setShortName(source)
+          case _ =>
+        }
         val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
         import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
         table match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
index 42eaa492be73..6f4bd1888fbb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.File
+import java.util.concurrent.CountDownLatch
 
 import scala.concurrent.duration._
 
@@ -24,7 +25,8 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, shouldTestPandasUDFs}
 import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonStreamingSourceOffset}
-import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.{MemoryStream, ProcessingTimeTrigger}
+import org.apache.spark.sql.streaming.StreamingQueryException
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -39,10 +41,12 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
       |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
       |
       |class SimpleDataStreamReader(DataSourceStreamReader):
+      |    current = 0
       |    def initialOffset(self):
       |        return {"offset": {"partition-1": 0}}
       |    def latestOffset(self):
-      |        return {"offset": {"partition-1": 2}}
+      |        self.current += 2
+      |        return {"offset": {"partition-1": self.current}}
       |    def partitions(self, start: dict, end: dict):
       |        start_index = start["offset"]["partition-1"]
       |        end_index = end["offset"]["partition-1"]
@@ -50,9 +54,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
       |    def commit(self, end: dict):
       |        1 + 2
       |    def read(self, partition):
-      |        yield (0, partition.value)
-      |        yield (1, partition.value)
-      |        yield (2, partition.value)
+      |        yield (partition.value,)
       |""".stripMargin
 
   protected def errorDataStreamReaderScript: String =
@@ -110,7 +112,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
   private val errorDataSourceName = "ErrorDataSource"
 
-  test("simple data stream source") {
+  test("Test PythonMicroBatchStream") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
@@ -130,40 +132,178 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     val stream = new PythonMicroBatchStream(
       pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
 
-    val initialOffset = stream.initialOffset()
-    assert(initialOffset.json == "{\"offset\": {\"partition-1\": 0}}")
-    for (_ <- 1 to 50) {
-      val offset = stream.latestOffset()
-      assert(offset.json == "{\"offset\": {\"partition-1\": 2}}")
-      assert(stream.planInputPartitions(initialOffset, offset).size == 2)
-      stream.commit(offset)
+    var startOffset = stream.initialOffset()
+    assert(startOffset.json == "{\"offset\": {\"partition-1\": 0}}")
+    for (i <- 1 to 50) {
+      val endOffset = stream.latestOffset()
+      assert(endOffset.json == s"""{"offset": {"partition-1": ${2 * i}}}""")
+      assert(stream.planInputPartitions(startOffset, endOffset).size == 2)
+      stream.commit(endOffset)
+      startOffset = endOffset
     }
     stream.stop()
   }
 
+  test("Read from simple data stream source") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def streamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val df = spark.readStream.format(dataSourceName).load()
+
+    val stopSignal = new CountDownLatch(1)
+
+    val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
+      // checkAnswer may materialize the dataframe more than once
+      // Cache here to make sure the numInputRows metrics is consistent.
+      df.cache()
+      checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+      if (batchId > 30) stopSignal.countDown()
+    }).trigger(ProcessingTimeTrigger(0)).start()
+    stopSignal.await()
+    assert(q.recentProgress.forall(_.numInputRows == 2))
+    q.stop()
+    q.awaitTermination()
+  }
+
+  test("Streaming data source read with custom partitions") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceStreamReader, InputPartition
+         |class RangePartition(InputPartition):
+         |    def __init__(self, start, end):
+         |        self.start = start
+         |        self.end = end
+         |
+         |class SimpleDataStreamReader(DataSourceStreamReader):
+         |    current = 0
+         |    def initialOffset(self):
+         |        return {"offset": 0}
+         |    def latestOffset(self):
+         |        self.current += 2
+         |        return {"offset": self.current}
+         |    def partitions(self, start: dict, end: dict):
+         |        return [RangePartition(start["offset"], end["offset"])]
+         |    def commit(self, end: dict):
+         |        1 + 2
+         |    def read(self, partition: RangePartition):
+         |        start, end = partition.start, partition.end
+         |        for i in range(start, end):
+         |            yield (i, )
+         |
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |
+         |    def streamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val df = spark.readStream.format(dataSourceName).load()
+
+    val stopSignal = new CountDownLatch(1)
+
+    val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
+      // checkAnswer may materialize the dataframe more than once
+      // Cache here to make sure the numInputRows metrics is consistent.
+      df.cache()
+      checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+      if (batchId > 30) stopSignal.countDown()
+    }).trigger(ProcessingTimeTrigger(0)).start()
+    stopSignal.await()
+    assert(q.recentProgress.forall(_.numInputRows == 2))
+    q.stop()
+    q.awaitTermination()
+  }
+
   test("Error creating stream reader") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource
          |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
          |    def streamReader(self, schema):
          |        raise Exception("error creating stream reader")
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(
       name = dataSourceName, pythonScript = dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
-    val pythonDs = new PythonDataSourceV2
-    pythonDs.setShortName("SimpleDataSource")
-    val inputSchema = StructType.fromDDL("input BINARY")
-    val err = intercept[AnalysisException] {
-      new PythonMicroBatchStream(
-        pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+
+    val err = intercept[StreamingQueryException] {
+      val q = spark.readStream.format(dataSourceName).load()
+        .writeStream.format("console").start()
+      q.awaitTermination()
     }
-    assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR")
+    assert(err.getErrorClass == "STREAM_FAILED")
     assert(err.getMessage.contains("error creating stream reader"))
   }
 
+  test("Streaming data source read error") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceStreamReader, InputPartition
+         |class RangePartition(InputPartition):
+         |    def __init__(self, start, end):
+         |        self.start = start
+         |        self.end = end
+         |
+         |class SimpleDataStreamReader(DataSourceStreamReader):
+         |    current = 0
+         |    def initialOffset(self):
+         |        return {"offset": "0"}
+         |    def latestOffset(self):
+         |        self.current += 2
+         |        return {"offset": str(self.current)}
+         |    def partitions(self, start: dict, end: dict):
+         |        return [RangePartition(int(start["offset"]), int(end["offset"]))]
+         |    def commit(self, end: dict):
+         |        1 + 2
+         |    def read(self, partition: RangePartition):
+         |        raise Exception("error reading data")
+         |
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |
+         |    def streamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val df = spark.readStream.format(dataSourceName).load()
+
+    val err = intercept[StreamingQueryException] {
+      val q = df.writeStream.foreachBatch((df: DataFrame, _: Long) => {
+        df.count()
+        ()
+      }).start()
+      q.awaitTermination()
+    }
+    assert(err.getMessage.contains("error reading data"))
+  }
+
+
   test("Method not implemented in stream reader") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =
@@ -237,7 +377,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     spark.dataSource.registerPython(errorDataSourceName, dataSource)
     val pythonDs = new PythonDataSourceV2
     pythonDs.setShortName("ErrorDataSource")
-    val offset = PythonStreamingSourceOffset("{\"offset\": \"2\"}")
+    val offset = PythonStreamingSourceOffset("{\"offset\": 2}")
 
     def testMicroBatchStreamError(action: String, msg: String)
                                  (func: PythonMicroBatchStream => Unit): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
