This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c8c249204178 [SPARK-47793][SS][PYTHON] Implement 
SimpleDataSourceStreamReader for python streaming data source
c8c249204178 is described below

commit c8c2492041782b9be7f10647191dcd0d5f6a5a8a
Author: Chaoqin Li <chaoqin...@databricks.com>
AuthorDate: Tue Apr 30 22:08:32 2024 +0900

    [SPARK-47793][SS][PYTHON] Implement SimpleDataSourceStreamReader for python 
streaming data source
    
    ### What changes were proposed in this pull request?
    SimpleDataSourceStreamReader is a simplified version of the 
DataSourceStreamReader interface.
    
    There are 3 functions that need to be defined:
    
    1. Read data and return the end offset.
    _def read(self, start: Offset) -> (Iterator[Tuple], Offset)_
    
    2. Read data between the start and end offsets; this is required for exactly-once reads.
    _def readBetweenOffsets(self, start: Offset, end: Offset) -> Iterator[Tuple]_
    
    3. Return the initial start offset of the streaming query.
    _def initialOffset(self) -> dict_
    
    The implementation wraps the SimpleDataSourceStreamReader instance in a DataSourceStreamReader that prefetches and caches data in latestOffset(). The records prefetched in the Python process are sent to the JVM as Arrow record batches in planInputPartitions(), cached by the block manager, and later read by the partition reader on the executors.
    
    ### Why are the changes needed?
    Compared to the DataSourceStreamReader interface, the simplified interface has some advantages:
    - It doesn’t require developers to reason about data partitioning.
    - It doesn’t require getting the latest offset before reading data.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added unit tests and integration tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45977 from chaoqin-li1123/simple_reader_impl.
    
    Lead-authored-by: Chaoqin Li <chaoqin...@databricks.com>
    Co-authored-by: chaoqin-li1123 
<55518381+chaoqin-li1...@users.noreply.github.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../scala/org/apache/spark/storage/BlockId.scala   |   8 +
 python/pyspark/sql/datasource.py                   | 129 ++++++++-
 python/pyspark/sql/datasource_internal.py          | 146 ++++++++++
 .../streaming/python_streaming_source_runner.py    |  58 +++-
 python/pyspark/sql/worker/plan_data_source_read.py | 142 +++++-----
 .../v2/python/PythonMicroBatchStream.scala         |  61 +++-
 .../datasources/v2/python/PythonScan.scala         |   3 +
 .../PythonStreamingPartitionReaderFactory.scala    |  89 ++++++
 .../python/PythonStreamingSourceRunner.scala       |  57 +++-
 .../python/PythonStreamingDataSourceSuite.scala    | 307 ++++++++++++++++++++-
 10 files changed, 911 insertions(+), 89 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 585d9a886b47..6eb015d56b2c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -170,6 +170,11 @@ case class StreamBlockId(streamId: Int, uniqueId: Long) 
extends BlockId {
   override def name: String = "input-" + streamId + "-" + uniqueId
 }
 
+@DeveloperApi
+case class PythonStreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
+  override def name: String = "python-stream-" + streamId + "-" + uniqueId
+}
+
 /** Id associated with temporary local data managed as blocks. Not 
serializable. */
 private[spark] case class TempLocalBlockId(id: UUID) extends BlockId {
   override def name: String = "temp_local_" + id
@@ -213,6 +218,7 @@ object BlockId {
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
+  val PYTHON_STREAM = "python-stream-([0-9]+)-([0-9]+)".r
   val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r
   val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r
   val TEST = "test_(.*)".r
@@ -250,6 +256,8 @@ object BlockId {
       TaskResultBlockId(taskId.toLong)
     case STREAM(streamId, uniqueId) =>
       StreamBlockId(streamId.toInt, uniqueId.toLong)
+    case PYTHON_STREAM(streamId, uniqueId) =>
+      PythonStreamBlockId(streamId.toInt, uniqueId.toLong)
     case TEMP_LOCAL(uuid) =>
       TempLocalBlockId(UUID.fromString(uuid))
     case TEMP_SHUFFLE(uuid) =>
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index c08b5b7af77f..6cac7e35ff41 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -183,11 +183,36 @@ class DataSource(ABC):
             message_parameters={"feature": "streamWriter"},
         )
 
+    def simpleStreamReader(self, schema: StructType) -> 
"SimpleDataSourceStreamReader":
+        """
+        Returns a :class:`SimpleDataSourceStreamReader` instance for reading 
data.
+
+        One of simpleStreamReader() and streamReader() must be implemented for a readable
+        streaming data source. Spark checks whether streamReader() is implemented; if so, it
+        creates a DataSourceStreamReader to read data. simpleStreamReader() is only invoked
+        when streamReader() is not implemented.
+
+        Parameters
+        ----------
+        schema : :class:`StructType`
+            The schema of the data to be read.
+
+        Returns
+        -------
+        reader : :class:`SimpleDataSourceStreamReader`
+            A reader instance for this data source.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "simpleStreamReader"},
+        )
+
     def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
         """
         Returns a :class:`DataSourceStreamReader` instance for reading 
streaming data.
 
-        The implementation is required for readable streaming data sources.
+        One of simpleStreamReader() and streamReader() must be implemented for 
readable streaming
+        data source.
 
         Parameters
         ----------
@@ -396,8 +421,10 @@ class DataSourceStreamReader(ABC):
 
     def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]:
         """
-        Returns a list of InputPartition  given the start and end offsets. 
Each InputPartition
-        represents a data split that can be processed by one Spark task.
+        Returns a list of InputPartition given the start and end offsets. Each InputPartition
+        represents a data split that can be processed by one Spark task. This may be called with
+        an empty offset range when start == end, in which case the method should return
+        an empty sequence of InputPartition.
 
         Parameters
         ----------
@@ -469,6 +496,102 @@ class DataSourceStreamReader(ABC):
         ...
 
 
+class SimpleDataSourceStreamReader(ABC):
+    """
+    A base class for simplified streaming data source readers.
+    Compared to :class:`DataSourceStreamReader`, 
:class:`SimpleDataSourceStreamReader` doesn't
+    require planning data partition. Also, the read api of 
:class:`SimpleDataSourceStreamReader`
+    allows reading data and planning the latest offset at the same time.
+
+    Because :class:`SimpleDataSourceStreamReader` reads records on the Spark driver node to
+    determine the end offset of each batch without partitioning, it is only supposed to be
+    used in lightweight use cases where the input rate and batch size are small.
+    Use :class:`DataSourceStreamReader` when read throughput is high and can't 
be handled
+    by a single process.
+
+    .. versionadded:: 4.0.0
+    """
+
+    def initialOffset(self) -> dict:
+        """
+        Return the initial offset of the streaming data source.
+        A new streaming query starts reading data from the initial offset.
+        If Spark is restarting an existing query, it will restart from the 
check-pointed offset
+        rather than the initial one.
+
+        Returns
+        -------
+        dict
+            A dict or recursive dict whose key and value are primitive types, 
which includes
+            Integer, String and Boolean.
+
+        Examples
+        --------
+        >>> def initialOffset(self):
+        ...     return {"partition-1": {"index": 3, "closed": True}, "partition-2": {"index": 5}}
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "initialOffset"},
+        )
+
+    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
+        """
+        Read all available data from start offset and return the offset that 
next read attempt
+        starts from.
+
+        Parameters
+        ----------
+        start : dict
+            The start offset to start reading from.
+
+        Returns
+        -------
+        A :class:`Tuple` of an iterator of :class:`Tuple`\\s and a dict
+            The iterator contains all the available records after start offset.
+            The dict is the end offset of this read attempt and the start of 
next read attempt.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "read"},
+        )
+
+    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
+        """
+        Read all available data between the specified start offset and end offset.
+        This is invoked during failure recovery to re-read a batch 
deterministically.
+
+        Parameters
+        ----------
+        start : dict
+            The start offset to start reading from.
+
+        end : dict
+            The offset where the reading stops.
+
+        Returns
+        -------
+        iterator of :class:`Tuple`\\s
+            All the records between start offset and end offset.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "readBetweenOffsets"},
+        )
+
+    def commit(self, end: dict) -> None:
+        """
+        Informs the source that Spark has completed processing all data for 
offsets less than or
+        equal to `end` and will only request offsets greater than `end` in the 
future.
+
+        Parameters
+        ----------
+        end : dict
+            The latest offset that the streaming query has processed for this 
source.
+        """
+        ...
+
+
 class DataSourceWriter(ABC):
     """
     A base class for data source writers. Data source writers are responsible 
for saving
diff --git a/python/pyspark/sql/datasource_internal.py 
b/python/pyspark/sql/datasource_internal.py
new file mode 100644
index 000000000000..6df0be4192ec
--- /dev/null
+++ b/python/pyspark/sql/datasource_internal.py
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) -> 
"DataSourceStreamReader":
+    """
+    Fall back to the simpleStreamReader() method when streamReader() is not implemented.
+    This should be invoked whenever a DataSourceStreamReader needs to be 
created instead of
+    invoking datasource.streamReader() directly.
+    """
+    try:
+        return datasource.streamReader(schema=schema)
+    except PySparkNotImplementedError:
+        return 
_SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry:
+    def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+        self.start = start
+        self.end = end
+        self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+    """
+    A private class that wraps :class:`SimpleDataSourceStreamReader` in a prefetch-and-cache
+    pattern, so that :class:`SimpleDataSourceStreamReader` can integrate with the streaming engine
+    like an ordinary :class:`DataSourceStreamReader`.
+
+    current_offset tracks the latest progress of record prefetching. It is initialized to
+    initialOffset() when the query starts for the first time, or to the end offset of the last
+    planned batch when the query restarts.
+
+    When the streaming engine calls latestOffset(), the wrapper calls read() starting from
+    current_offset, prefetches and caches the data, then updates current_offset to the end
+    offset of the new data.
+
+    When the streaming engine calls planInputPartitions(start, end), the wrapper gets the
+    prefetched data from the cache and sends it to the JVM along with the input partitions.
+
+    When the query restarts, batches in the write-ahead offset log that have not been committed
+    are replayed by reading data between the start and end offsets through
+    readBetweenOffsets(start, end).
+    """
+
+    def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+        self.simple_reader = simple_reader
+        self.initial_offset: Optional[dict] = None
+        self.current_offset: Optional[dict] = None
+        self.cache: List[PrefetchedCacheEntry] = []
+
+    def initialOffset(self) -> dict:
+        if self.initial_offset is None:
+            self.initial_offset = self.simple_reader.initialOffset()
+        return self.initial_offset
+
+    def latestOffset(self) -> dict:
+        # When the query starts for the first time, use the initial offset as the start offset.
+        if self.current_offset is None:
+            self.current_offset = self.initialOffset()
+        (iter, end) = self.simple_reader.read(self.current_offset)
+        self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
+        self.current_offset = end
+        return end
+
+    def commit(self, end: dict) -> None:
+        if self.current_offset is None:
+            self.current_offset = end
+
+        end_idx = -1
+        for idx, entry in enumerate(self.cache):
+            if json.dumps(entry.end) == json.dumps(end):
+                end_idx = idx
+                break
+        if end_idx > 0:
+            # Drop prefetched data for batch that has been committed.
+            self.cache = self.cache[end_idx:]
+        self.simple_reader.commit(end)
+
+    def partitions(self, start: dict, end: dict) -> Sequence["InputPartition"]:
+        # When the query restarts from a checkpoint, use the last committed offset as the start
+        # offset. This depends on the streaming engine calling planInputPartitions() for the last
+        # batch in the offset log when the query restarts.
+        if self.current_offset is None:
+            self.current_offset = end
+        if len(self.cache) > 0:
+            assert self.cache[-1].end == end
+        return [SimpleInputPartition(start, end)]
+
+    def getCache(self, start: dict, end: dict) -> Iterator[Tuple]:
+        start_idx = -1
+        end_idx = -1
+        for idx, entry in enumerate(self.cache):
+            # There is no convenient way to compare 2 offsets.
+            # Serialize into json string before comparison.
+            if json.dumps(entry.start) == json.dumps(start):
+                start_idx = idx
+            if json.dumps(entry.end) == json.dumps(end):
+                end_idx = idx
+                break
+        if start_idx == -1 or end_idx == -1:
+            return None  # type: ignore[return-value]
+        # Chain all the data iterator between start offset and end offset
+        # need to copy here to avoid exhausting the original data iterator.
+        entries = [copy.copy(entry.iterator) for entry in self.cache[start_idx 
: end_idx + 1]]
+        it = chain(*entries)
+        return it
+
+    def read(
+        self, input_partition: SimpleInputPartition  # type: ignore[override]
+    ) -> Iterator[Tuple]:
+        return self.simple_reader.readBetweenOffsets(input_partition.start, 
input_partition.end)
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py 
b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 8109403b42dd..946344fbc6c9 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -18,9 +18,10 @@
 import os
 import sys
 import json
-from typing import IO
+from typing import IO, Iterator, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.errors import IllegalArgumentException, PySparkAssertionError, 
PySparkRuntimeError
 from pyspark.serializers import (
     read_int,
@@ -29,11 +30,14 @@ from pyspark.serializers import (
     SpecialLengths,
 )
 from pyspark.sql.datasource import DataSource, DataSourceStreamReader
+from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper, 
_streamReader
+from pyspark.sql.pandas.serializers import ArrowStreamSerializer
 from pyspark.sql.types import (
     _parse_datatype_json_string,
     StructType,
 )
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.sql.worker.plan_data_source_read import records_to_arrow_batches
+from pyspark.util import handle_worker_exception
 from pyspark.worker_util import (
     check_python_version,
     read_command,
@@ -49,6 +53,10 @@ LATEST_OFFSET_FUNC_ID = 885
 PARTITIONS_FUNC_ID = 886
 COMMIT_FUNC_ID = 887
 
+PREFETCHED_RECORDS_NOT_FOUND = 0
+NON_EMPTY_PYARROW_RECORD_BATCHES = 1
+EMPTY_PYARROW_RECORD_BATCHES = 2
+
 
 def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
     offset = reader.initialOffset()
@@ -60,7 +68,14 @@ def latest_offset_func(reader: DataSourceStreamReader, 
outfile: IO) -> None:
     write_with_length(json.dumps(offset).encode("utf-8"), outfile)
 
 
-def partitions_func(reader: DataSourceStreamReader, infile: IO, outfile: IO) 
-> None:
+def partitions_func(
+    reader: DataSourceStreamReader,
+    data_source: DataSource,
+    schema: StructType,
+    max_arrow_batch_size: int,
+    infile: IO,
+    outfile: IO,
+) -> None:
     start_offset = json.loads(utf8_deserializer.loads(infile))
     end_offset = json.loads(utf8_deserializer.loads(infile))
     partitions = reader.partitions(start_offset, end_offset)
@@ -68,6 +83,14 @@ def partitions_func(reader: DataSourceStreamReader, infile: 
IO, outfile: IO) ->
     write_int(len(partitions), outfile)
     for partition in partitions:
         pickleSer._write_with_length(partition, outfile)
+    if isinstance(reader, _SimpleStreamReaderWrapper):
+        it = reader.getCache(start_offset, end_offset)
+        if it is None:
+            write_int(PREFETCHED_RECORDS_NOT_FOUND, outfile)
+        else:
+            send_batch_func(it, outfile, schema, max_arrow_batch_size, 
data_source)
+    else:
+        write_int(PREFETCHED_RECORDS_NOT_FOUND, outfile)
 
 
 def commit_func(reader: DataSourceStreamReader, infile: IO, outfile: IO) -> 
None:
@@ -76,6 +99,23 @@ def commit_func(reader: DataSourceStreamReader, infile: IO, 
outfile: IO) -> None
     write_int(0, outfile)
 
 
+def send_batch_func(
+    rows: Iterator[Tuple],
+    outfile: IO,
+    schema: StructType,
+    max_arrow_batch_size: int,
+    data_source: DataSource,
+) -> None:
+    batches = list(records_to_arrow_batches(rows, max_arrow_batch_size, 
schema, data_source))
+    if len(batches) != 0:
+        write_int(NON_EMPTY_PYARROW_RECORD_BATCHES, outfile)
+        write_int(SpecialLengths.START_ARROW_STREAM, outfile)
+        serializer = ArrowStreamSerializer()
+        serializer.dump_stream(batches, outfile)
+    else:
+        write_int(EMPTY_PYARROW_RECORD_BATCHES, outfile)
+
+
 def main(infile: IO, outfile: IO) -> None:
     try:
         check_python_version(infile)
@@ -110,9 +150,15 @@ def main(infile: IO, outfile: IO) -> None:
                 },
             )
 
+        max_arrow_batch_size = read_int(infile)
+        assert max_arrow_batch_size > 0, (
+            "The maximum arrow batch size should be greater than 0, but got "
+            f"'{max_arrow_batch_size}'"
+        )
+
         # Instantiate data source reader.
         try:
-            reader = data_source.streamReader(schema=schema)
+            reader = _streamReader(data_source, schema)
             # Initialization succeed.
             write_int(0, outfile)
             outfile.flush()
@@ -125,7 +171,9 @@ def main(infile: IO, outfile: IO) -> None:
                 elif func_id == LATEST_OFFSET_FUNC_ID:
                     latest_offset_func(reader, outfile)
                 elif func_id == PARTITIONS_FUNC_ID:
-                    partitions_func(reader, infile, outfile)
+                    partitions_func(
+                        reader, data_source, schema, max_arrow_batch_size, 
infile, outfile
+                    )
                 elif func_id == COMMIT_FUNC_ID:
                     commit_func(reader, infile, outfile)
                 else:
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py 
b/python/pyspark/sql/worker/plan_data_source_read.py
index 8a8b2cab91d8..be7ebd20f180 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -18,8 +18,9 @@
 import os
 import sys
 import functools
+import pyarrow as pa
 from itertools import islice
-from typing import IO, List, Iterator, Iterable
+from typing import IO, List, Iterator, Iterable, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
@@ -32,6 +33,7 @@ from pyspark.serializers import (
 from pyspark.sql import Row
 from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, 
LocalDataToArrowConversion
 from pyspark.sql.datasource import DataSource, InputPartition
+from pyspark.sql.datasource_internal import _streamReader
 from pyspark.sql.pandas.types import to_arrow_schema
 from pyspark.sql.types import (
     _parse_datatype_json_string,
@@ -51,6 +53,78 @@ from pyspark.worker_util import (
 )
 
 
+def records_to_arrow_batches(
+    output_iter: Iterator[Tuple],
+    max_arrow_batch_size: int,
+    return_type: StructType,
+    data_source: DataSource,
+) -> Iterable[pa.RecordBatch]:
+    """
+    Convert an iterator of Python tuples to an iterator of pyarrow record 
batches.
+
+    For each python tuple, check the types of each field and append it to the 
records batch.
+
+    """
+
+    def batched(iterator: Iterator, n: int) -> Iterator:
+        return iter(functools.partial(lambda it: list(islice(it, n)), 
iterator), [])
+
+    pa_schema = to_arrow_schema(return_type)
+    column_names = return_type.fieldNames()
+    column_converters = [
+        LocalDataToArrowConversion._create_converter(field.dataType) for field 
in return_type.fields
+    ]
+    # Convert the results from the `reader.read` method to an iterator of 
arrow batches.
+    num_cols = len(column_names)
+    col_mapping = {name: i for i, name in enumerate(column_names)}
+    col_name_set = set(column_names)
+    for batch in batched(output_iter, max_arrow_batch_size):
+        pylist: List[List] = [[] for _ in range(num_cols)]
+        for result in batch:
+            # Validate the output row schema.
+            if hasattr(result, "__len__") and len(result) != num_cols:
+                raise PySparkRuntimeError(
+                    error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
+                    message_parameters={
+                        "expected": str(num_cols),
+                        "actual": str(len(result)),
+                    },
+                )
+
+            # Validate the output row type.
+            if not isinstance(result, (list, tuple)):
+                raise PySparkRuntimeError(
+                    error_class="DATA_SOURCE_INVALID_RETURN_TYPE",
+                    message_parameters={
+                        "type": type(result).__name__,
+                        "name": data_source.name(),
+                        "supported_types": "tuple, list, 
`pyspark.sql.types.Row`",
+                    },
+                )
+
+            # Assign output values by name of the field, not position, if the 
result is a
+            # named `Row` object.
+            if isinstance(result, Row) and hasattr(result, "__fields__"):
+                # Check if the names are the same as the schema.
+                if set(result.__fields__) != col_name_set:
+                    raise PySparkRuntimeError(
+                        
error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+                        message_parameters={
+                            "expected": str(column_names),
+                            "actual": str(result.__fields__),
+                        },
+                    )
+                # Assign the values by name.
+                for name in column_names:
+                    idx = col_mapping[name]
+                    pylist[idx].append(column_converters[idx](result[name]))
+            else:
+                for col in range(num_cols):
+                    pylist[col].append(column_converters[col](result[col]))
+        batch = pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
+        yield batch
+
+
 def main(infile: IO, outfile: IO) -> None:
     """
     Main method for planning a data source read.
@@ -131,25 +205,16 @@ def main(infile: IO, outfile: IO) -> None:
 
         # Instantiate data source reader.
         reader = (
-            data_source.streamReader(schema=schema)
+            _streamReader(data_source, schema)
             if is_streaming
             else data_source.reader(schema=schema)
         )
 
-        # Wrap the data source read logic in an mapInArrow UDF.
-        import pyarrow as pa
-
         # Create input converter.
         converter = ArrowTableToRowsConversion._create_converter(BinaryType())
 
         # Create output converter.
         return_type = schema
-        pa_schema = to_arrow_schema(return_type)
-        column_names = return_type.fieldNames()
-        column_converters = [
-            LocalDataToArrowConversion._create_converter(field.dataType)
-            for field in return_type.fields
-        ]
 
         def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> 
Iterable[pa.RecordBatch]:
             partition_bytes = None
@@ -189,58 +254,9 @@ def main(infile: IO, outfile: IO) -> None:
                     },
                 )
 
-            def batched(iterator: Iterator, n: int) -> Iterator:
-                return iter(functools.partial(lambda it: list(islice(it, n)), 
iterator), [])
-
-            # Convert the results from the `reader.read` method to an iterator 
of arrow batches.
-            num_cols = len(column_names)
-            col_mapping = {name: i for i, name in enumerate(column_names)}
-            col_name_set = set(column_names)
-            for batch in batched(output_iter, max_arrow_batch_size):
-                pylist: List[List] = [[] for _ in range(num_cols)]
-                for result in batch:
-                    # Validate the output row schema.
-                    if hasattr(result, "__len__") and len(result) != num_cols:
-                        raise PySparkRuntimeError(
-                            error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
-                            message_parameters={
-                                "expected": str(num_cols),
-                                "actual": str(len(result)),
-                            },
-                        )
-
-                    # Validate the output row type.
-                    if not isinstance(result, (list, tuple)):
-                        raise PySparkRuntimeError(
-                            error_class="DATA_SOURCE_INVALID_RETURN_TYPE",
-                            message_parameters={
-                                "type": type(result).__name__,
-                                "name": data_source.name(),
-                                "supported_types": "tuple, list, 
`pyspark.sql.types.Row`",
-                            },
-                        )
-
-                    # Assign output values by name of the field, not position, 
if the result is a
-                    # named `Row` object.
-                    if isinstance(result, Row) and hasattr(result, 
"__fields__"):
-                        # Check if the names are the same as the schema.
-                        if set(result.__fields__) != col_name_set:
-                            raise PySparkRuntimeError(
-                                
error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
-                                message_parameters={
-                                    "expected": str(column_names),
-                                    "actual": str(result.__fields__),
-                                },
-                            )
-                        # Assign the values by name.
-                        for name in column_names:
-                            idx = col_mapping[name]
-                            
pylist[idx].append(column_converters[idx](result[name]))
-                    else:
-                        for col in range(num_cols):
-                            
pylist[col].append(column_converters[col](result[col]))
-
-                yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
+            return records_to_arrow_batches(
+                output_iter, max_arrow_batch_size, return_type, data_source
+            )
 
         command = (data_source_read_func, return_type)
         pickleSer._write_with_length(command, outfile)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index 71e6c29bc299..0fc1df4cd1e9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -16,12 +16,15 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.python
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connector.read.{InputPartition, 
PartitionReaderFactory}
-import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
+import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, 
MicroBatchStream, Offset}
+import 
org.apache.spark.sql.execution.datasources.v2.python.PythonMicroBatchStream.nextStreamId
 import org.apache.spark.sql.execution.python.PythonStreamingSourceRunner
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.storage.{PythonStreamBlockId, StorageLevel}
 
 case class PythonStreamingSourceOffset(json: String) extends Offset
 
@@ -30,11 +33,22 @@ class PythonMicroBatchStream(
     shortName: String,
     outputSchema: StructType,
     options: CaseInsensitiveStringMap
-  ) extends MicroBatchStream with Logging {
+  )
+  extends MicroBatchStream
+  with Logging
+  with AcceptsLatestSeenOffset {
   private def createDataSourceFunc =
     ds.source.createPythonFunction(
       ds.getOrCreateDataSourceInPython(shortName, options, 
Some(outputSchema)).dataSource)
 
+  private val streamId = nextStreamId
+  private var nextBlockId = 0L
+
+  // planInputPartitions() maybe be called multiple times for the current 
microbatch.
+  // Cache the result of planInputPartitions() because it may involve sending 
data
+  // from python to JVM.
+  private var cachedInputPartition: Option[(String, String, 
PythonStreamingInputPartition)] = None
+
   private val runner: PythonStreamingSourceRunner =
     new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
   runner.init()
@@ -44,9 +58,35 @@ class PythonMicroBatchStream(
   override def latestOffset(): Offset = 
PythonStreamingSourceOffset(runner.latestOffset())
 
   override def planInputPartitions(start: Offset, end: Offset): 
Array[InputPartition] = {
-    runner.partitions(start.asInstanceOf[PythonStreamingSourceOffset].json,
-      end.asInstanceOf[PythonStreamingSourceOffset].json)
-      .zipWithIndex.map(p => PythonInputPartition(p._2, p._1))
+    val startOffsetJson = start.asInstanceOf[PythonStreamingSourceOffset].json
+    val endOffsetJson = end.asInstanceOf[PythonStreamingSourceOffset].json
+
+    if (cachedInputPartition.exists(p => p._1 == startOffsetJson && p._2 == 
endOffsetJson)) {
+      return Array(cachedInputPartition.get._3)
+    }
+
+    val (partitions, rows) = runner.partitions(startOffsetJson, endOffsetJson)
+    if (rows.isDefined) {
+      // Only SimpleStreamReader without partitioning prefetch data.
+      assert(partitions.length == 1)
+      nextBlockId = nextBlockId + 1
+      val blockId = PythonStreamBlockId(streamId, nextBlockId)
+      SparkEnv.get.blockManager.putIterator(
+        blockId, rows.get, StorageLevel.MEMORY_AND_DISK_SER, true)
+      val partition = PythonStreamingInputPartition(0, partitions.head, 
Some(blockId))
+      cachedInputPartition.foreach(_._3.dropCache())
+      cachedInputPartition = Some((startOffsetJson, endOffsetJson, partition))
+      Array(partition)
+    } else {
+      partitions.zipWithIndex
+        .map(p => PythonStreamingInputPartition(p._2, p._1, None))
+    }
+  }
+
+  override def setLatestSeenOffset(offset: Offset): Unit = {
+    // Call planPartition on python with an empty offset range to initialize 
the start offset
+    // for the prefetching of simple reader.
+    runner.partitions(offset.json(), offset.json())
   }
 
   private lazy val readInfo: PythonDataSourceReadInfo = {
@@ -57,7 +97,7 @@ class PythonMicroBatchStream(
   }
 
   override def createReaderFactory(): PartitionReaderFactory = {
-    new PythonPartitionReaderFactory(
+    new PythonStreamingPartitionReaderFactory(
       ds.source, readInfo.func, outputSchema, None)
   }
 
@@ -66,9 +106,18 @@ class PythonMicroBatchStream(
   }
 
   override def stop(): Unit = {
+    cachedInputPartition.foreach(_._3.dropCache())
     runner.stop()
   }
 
   override def deserializeOffset(json: String): Offset = 
PythonStreamingSourceOffset(json)
 }
 
+object PythonMicroBatchStream {
+  private var currentId = 0
+  def nextStreamId: Int = synchronized {
+    currentId = currentId + 1
+    currentId
+  }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index 8fefc8b144a1..8ebb91c01fc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -41,6 +41,9 @@ class PythonScan(
 
   override def supportedCustomMetrics(): Array[CustomMetric] =
     ds.source.createPythonMetrics()
+
+  override def columnarSupportMode(): Scan.ColumnarSupportMode =
+    Scan.ColumnarSupportMode.UNSUPPORTED
 }
 
 class PythonBatch(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
new file mode 100644
index 000000000000..75a38b8ea622
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, 
PartitionReaderFactory}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.PythonStreamBlockId
+
+
+case class PythonStreamingInputPartition(
+    index: Int,
+    pickedPartition: Array[Byte],
+    blockId: Option[PythonStreamBlockId]) extends InputPartition {
+  def dropCache(): Unit = {
+    blockId.foreach(SparkEnv.get.blockManager.master.removeBlock(_))
+  }
+}
+
+class PythonStreamingPartitionReaderFactory(
+    source: UserDefinedPythonDataSource,
+    pickledReadFunc: Array[Byte],
+    outputSchema: StructType,
+    jobArtifactUUID: Option[String])
+  extends PartitionReaderFactory with Logging {
+
+  override def createReader(partition: InputPartition): 
PartitionReader[InternalRow] = {
+    val part = partition.asInstanceOf[PythonStreamingInputPartition]
+
+    // Maybe read from cached block prefetched by SimpleStreamReader
+    lazy val cachedBlock = if (part.blockId.isDefined) {
+      val block = SparkEnv.get.blockManager.get[InternalRow](part.blockId.get)
+        .map(_.data.asInstanceOf[Iterator[InternalRow]])
+      if (block.isEmpty) {
+        logWarning(s"Prefetched block ${part.blockId} for Python data source 
not found.")
+      }
+      block
+    } else None
+
+    new PartitionReader[InternalRow] {
+
+      private[this] val metrics: Map[String, SQLMetric] = 
PythonCustomMetric.pythonMetrics
+
+      private val outputIter = if (cachedBlock.isEmpty) {
+        // Evaluate the python read UDF if the partition is not cached as 
block.
+        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+          pickledReadFunc,
+          "read_from_data_source",
+          UserDefinedPythonDataSource.readInputSchema,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      } else cachedBlock.get
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): InternalRow = outputIter.next()
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> 
v.value })
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
index 2ef046f668b6..a512b34db345 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
@@ -23,14 +23,20 @@ import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, Data
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 
+import org.apache.arrow.vector.ipc.ArrowStreamReader
+
 import org.apache.spark.SparkEnv
 import org.apache.spark.api.python.{PythonFunction, PythonWorker, 
PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.PYTHON_EXEC
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, 
ColumnVector}
 
 object PythonStreamingSourceRunner {
   // When the python process for python_streaming_source_runner receives one 
of the
@@ -39,6 +45,11 @@ object PythonStreamingSourceRunner {
   val LATEST_OFFSET_FUNC_ID = 885
   val PARTITIONS_FUNC_ID = 886
   val COMMIT_FUNC_ID = 887
+  // Status code for JVM to decide how to receive prefetched record batches
+  // for simple stream reader.
+  val PREFETCHED_RECORDS_NOT_FOUND = 0
+  val NON_EMPTY_PYARROW_RECORD_BATCHES = 1
+  val EMPTY_PYARROW_RECORD_BATCHES = 2
 }
 
 /**
@@ -102,6 +113,8 @@ class PythonStreamingSourceRunner(
     // Send output schema
     PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
 
+    dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+
     dataOut.flush()
 
     dataIn = new DataInputStream(
@@ -148,7 +161,8 @@ class PythonStreamingSourceRunner(
   /**
    * Invokes partitions(start, end) function of the stream reader and receive 
the return value.
    */
-  def partitions(start: String, end: String): Array[Array[Byte]] = {
+  def partitions(start: String, end: String): (Array[Array[Byte]], 
Option[Iterator[InternalRow]]) =
+  {
     dataOut.writeInt(PARTITIONS_FUNC_ID)
     PythonWorkerUtils.writeUTF(start, dataOut)
     PythonWorkerUtils.writeUTF(end, dataOut)
@@ -165,7 +179,20 @@ class PythonStreamingSourceRunner(
       val pickledPartition: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
       pickledPartitions.append(pickledPartition)
     }
-    pickledPartitions.toArray
+    val prefetchedRecordsStatus = dataIn.readInt()
+    val iter: Option[Iterator[InternalRow]] = prefetchedRecordsStatus match {
+      case NON_EMPTY_PYARROW_RECORD_BATCHES => Some(readArrowRecordBatches())
+      case PREFETCHED_RECORDS_NOT_FOUND => None
+      case EMPTY_PYARROW_RECORD_BATCHES => Some(Iterator.empty)
+      case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+        val msg = PythonWorkerUtils.readUTF(dataIn)
+        throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+          action = "planPartitions", msg)
+      case _ =>
+        throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+          action = "planPartitions", s"unknown status code 
$prefetchedRecordsStatus")
+    }
+    (pickledPartitions.toArray, iter)
   }
 
   /**
@@ -200,4 +227,30 @@ class PythonStreamingSourceRunner(
         logError("Exception when trying to kill worker", e)
     }
   }
+
+  private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+    s"stream reader for $pythonExec", 0, Long.MaxValue)
+
+  def readArrowRecordBatches(): Iterator[InternalRow] = {
+    assert(dataIn.readInt() == SpecialLengths.START_ARROW_STREAM)
+    val reader = new ArrowStreamReader(dataIn, allocator)
+    val root = reader.getVectorSchemaRoot()
+    // When input is empty schema can't be read.
+    val schema = ArrowUtils.fromArrowSchema(root.getSchema())
+    assert(schema == outputSchema)
+
+    val vectors = root.getFieldVectors().asScala.map { vector =>
+      new ArrowColumnVector(vector)
+    }.toArray[ColumnVector]
+    val rows = ArrayBuffer[InternalRow]()
+    while (reader.loadNextBatch()) {
+      val batch = new ColumnarBatch(vectors)
+      batch.setNumRows(root.getRowCount)
+      // Need to copy the row because the ColumnarBatch row iterator use
+      // the same underlying Internal row.
+      rows.appendAll(batch.rowIterator().asScala.map(_.copy()))
+    }
+    reader.close(false)
+    rows.iterator
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
index 6f4bd1888fbb..97e6467c3eaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import 
org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, 
shouldTestPandasUDFs}
 import 
org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, 
PythonMicroBatchStream, PythonStreamingSourceOffset}
-import org.apache.spark.sql.execution.streaming.{MemoryStream, 
ProcessingTimeTrigger}
+import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, 
OffsetSeqLog, ProcessingTimeTrigger}
 import org.apache.spark.sql.streaming.StreamingQueryException
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -36,11 +36,11 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
   val waitTimeout = 15.seconds
 
-  protected def simpleDataStreamReaderScript: String =
+  protected def testDataStreamReaderScript: String =
     """
       |from pyspark.sql.datasource import DataSourceStreamReader, 
InputPartition
       |
-      |class SimpleDataStreamReader(DataSourceStreamReader):
+      |class TestDataStreamReader(DataSourceStreamReader):
       |    current = 0
       |    def initialOffset(self):
       |        return {"offset": {"partition-1": 0}}
@@ -57,6 +57,43 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
       |        yield (partition.value,)
       |""".stripMargin
 
+  protected def simpleDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+      |
+      |class SimpleDataStreamReader(SimpleDataSourceStreamReader):
+      |    def initialOffset(self):
+      |        return {"partition-1": 0}
+      |    def read(self, start: dict):
+      |        start_idx = start["partition-1"]
+      |        it = iter([(i, ) for i in range(start_idx, start_idx + 2)])
+      |        return (it, {"partition-1": start_idx + 2})
+      |    def readBetweenOffsets(self, start: dict, end: dict):
+      |        start_idx = start["partition-1"]
+      |        end_idx = end["partition-1"]
+      |        return iter([(i, ) for i in range(start_idx, end_idx)])
+      |""".stripMargin
+
+  protected def simpleDataStreamReaderWithEmptyBatchScript: String =
+    """
+      |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+      |
+      |class SimpleDataStreamReader(SimpleDataSourceStreamReader):
+      |    def initialOffset(self):
+      |        return {"partition-1": 0}
+      |    def read(self, start: dict):
+      |        start_idx = start["partition-1"]
+      |        if start_idx % 4 == 0:
+      |            it = iter([(i, ) for i in range(start_idx, start_idx + 2)])
+      |        else:
+      |            it = iter([])
+      |        return (it, {"partition-1": start_idx + 2})
+      |    def readBetweenOffsets(self, start: dict, end: dict):
+      |        start_idx = start["partition-1"]
+      |        end_idx = end["partition-1"]
+      |        return iter([(i, ) for i in range(start_idx, end_idx)])
+      |""".stripMargin
+
   protected def errorDataStreamReaderScript: String =
     """
       |from pyspark.sql.datasource import DataSourceStreamReader, 
InputPartition
@@ -117,11 +154,11 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource
-         |$simpleDataStreamReaderScript
+         |$testDataStreamReaderScript
          |
          |class $dataSourceName(DataSource):
          |    def streamReader(self, schema):
-         |        return SimpleDataStreamReader()
+         |        return TestDataStreamReader()
          |""".stripMargin
     val inputSchema = StructType.fromDDL("input BINARY")
 
@@ -144,7 +181,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     stream.stop()
   }
 
-  test("Read from simple data stream source") {
+  test("SimpleDataSourceStreamReader run query and restart") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
@@ -154,9 +191,260 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
          |class $dataSourceName(DataSource):
          |    def schema(self) -> str:
          |        return "id INT"
-         |    def streamReader(self, schema):
+         |    def simpleStreamReader(self, schema):
          |        return SimpleDataStreamReader()
          |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    
assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val df = spark.readStream.format(dataSourceName).load()
+
+      val stopSignal1 = new CountDownLatch(1)
+
+      val q1 = df
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          df.cache()
+          checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+          if (batchId == 10) stopSignal1.countDown()
+        })
+        .start()
+      stopSignal1.await()
+      assert(q1.recentProgress.forall(_.numInputRows == 2))
+      q1.stop()
+      q1.awaitTermination()
+
+      val stopSignal2 = new CountDownLatch(1)
+      val q2 = df
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          df.cache()
+          checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+          if (batchId == 20) stopSignal2.countDown()
+        }).start()
+      stopSignal2.await()
+      assert(q2.recentProgress.forall(_.numInputRows == 2))
+      q2.stop()
+      q2.awaitTermination()
+    }
+  }
+
+  // Verify prefetch and cache pattern of SimpleDataSourceStreamReader handle 
empty
+  // data batch correctly.
+  test("SimpleDataSourceStreamReader read empty batch") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderWithEmptyBatchScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def simpleStreamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    
assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val df = spark.readStream.format(dataSourceName).load()
+
+      val stopSignal = new CountDownLatch(1)
+
+      val q = df
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          df.cache()
+          if (batchId % 2 == 0) {
+            checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+          } else {
+            assert(df.isEmpty)
+          }
+          if (batchId == 10) stopSignal.countDown()
+        })
+        .start()
+      stopSignal.await()
+      q.stop()
+      q.awaitTermination()
+    }
+  }
+
+  test("SimpleDataSourceStreamReader read exactly once") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def simpleStreamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    
assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val outputDir = new File(path, "output")
+      val df = spark.readStream.format(dataSourceName).load()
+      var lastBatch = 0
+      // Restart streaming query multiple times to verify exactly once 
guarantee.
+      for (i <- 1 to 5) {
+
+        if (i % 2 == 0) {
+          // Remove the last entry of commit log to test replaying microbatch 
during restart.
+          val offsetLog = new OffsetSeqLog(
+            spark, new File(checkpointDir, "offsets").getCanonicalPath)
+          val commitLog = new CommitLog(
+            spark, new File(checkpointDir, "commits").getCanonicalPath)
+          commitLog.purgeAfter(offsetLog.getLatest().get._1 - 1)
+        }
+
+        val q = df
+          .writeStream
+          .option("checkpointLocation", checkpointDir.getAbsolutePath)
+          .format("json")
+          .start(outputDir.getAbsolutePath)
+
+        while (q.recentProgress.length < 5) {
+          Thread.sleep(200)
+        }
+        q.stop()
+        q.awaitTermination()
+        lastBatch = q.lastProgress.batchId.toInt
+      }
+      assert(lastBatch > 20)
+      checkAnswer(spark.read.format("json").load(outputDir.getAbsolutePath),
+        (0 to  2 * lastBatch + 1).map(Row(_)))
+    }
+  }
+
+  test("initialOffset() method not implemented in 
SimpleDataSourceStreamReader") {
+    assume(shouldTestPandasUDFs)
+    val initialOffsetNotImplementedScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+         |class ErrorDataStreamReader(SimpleDataSourceStreamReader):
+         |    ...
+         |
+         |class $errorDataSourceName(DataSource):
+         |    def simpleStreamReader(self, schema):
+         |        return ErrorDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource =
+      createUserDefinedPythonDataSource(errorDataSourceName, 
initialOffsetNotImplementedScript)
+    spark.dataSource.registerPython(errorDataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("ErrorDataSource")
+
+    def testMicroBatchStreamError(action: String, msg: String)
+                                 (func: PythonMicroBatchStream => Unit): Unit 
= {
+      val stream = new PythonMicroBatchStream(
+        pythonDs, errorDataSourceName, inputSchema, 
CaseInsensitiveStringMap.empty())
+      val err = intercept[SparkException] {
+        func(stream)
+      }
+      checkErrorMatchPVals(err,
+        errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
+        parameters = Map(
+          "action" -> action,
+          "msg" -> "(.|\\n)*"
+        ))
+      assert(err.getMessage.contains(msg))
+      assert(err.getMessage.contains("ErrorDataSource"))
+      stream.stop()
+    }
+
+    testMicroBatchStreamError(
+      "initialOffset", "[NOT_IMPLEMENTED] initialOffset is not implemented") {
+      stream => stream.initialOffset()
+    }
+
+    // User don't need to implement latestOffset for 
SimpleDataSourceStreamReader.
+    // The latestOffset method of simple stream reader invokes initialOffset() 
and read()
+    // So the not implemented method is initialOffset.
+    testMicroBatchStreamError(
+      "latestOffset", "[NOT_IMPLEMENTED] initialOffset is not implemented") {
+      stream => stream.latestOffset()
+    }
+  }
+
+  test("read() method throw error in SimpleDataSourceStreamReader") {
+    assume(shouldTestPandasUDFs)
+    val initialOffsetNotImplementedScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+         |class ErrorDataStreamReader(SimpleDataSourceStreamReader):
+         |    def initialOffset(self):
+         |        return {"partition": 1}
+         |    def read(self, start):
+         |        raise Exception("error reading available data")
+         |
+         |class $errorDataSourceName(DataSource):
+         |    def simpleStreamReader(self, schema):
+         |        return ErrorDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource =
+      createUserDefinedPythonDataSource(errorDataSourceName, 
initialOffsetNotImplementedScript)
+    spark.dataSource.registerPython(errorDataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("ErrorDataSource")
+
+    def testMicroBatchStreamError(action: String, msg: String)
+                                 (func: PythonMicroBatchStream => Unit): Unit 
= {
+      val stream = new PythonMicroBatchStream(
+        pythonDs, errorDataSourceName, inputSchema, 
CaseInsensitiveStringMap.empty())
+      val err = intercept[SparkException] {
+        func(stream)
+      }
+      checkErrorMatchPVals(err,
+        errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
+        parameters = Map(
+          "action" -> action,
+          "msg" -> "(.|\\n)*"
+        ))
+      assert(err.getMessage.contains(msg))
+      assert(err.getMessage.contains("ErrorDataSource"))
+      stream.stop()
+    }
+
+    testMicroBatchStreamError(
+      "latestOffset", "Exception: error reading available data") {
+      stream => stream.latestOffset()
+    }
+  }
+
+  test("Read from test data stream source") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$testDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def streamReader(self, schema):
+         |        return TestDataStreamReader()
+         |""".stripMargin
 
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
@@ -188,7 +476,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
          |        self.start = start
          |        self.end = end
          |
-         |class SimpleDataStreamReader(DataSourceStreamReader):
+         |class TestDataStreamReader(DataSourceStreamReader):
          |    current = 0
          |    def initialOffset(self):
          |        return {"offset": 0}
@@ -210,7 +498,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
          |        return "id INT"
          |
          |    def streamReader(self, schema):
-         |        return SimpleDataStreamReader()
+         |        return TestDataStreamReader()
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
@@ -303,7 +591,6 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     assert(err.getMessage.contains("error reading data"))
   }
 
-
   test("Method not implemented in stream reader") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to