HeartSaVioR commented on code in PR #45977:
URL: https://github.com/apache/spark/pull/45977#discussion_r1580561802


##########
python/pyspark/sql/datasource.py:
##########
@@ -469,6 +494,103 @@ def stop(self) -> None:
         ...
 
 
+class SimpleDataSourceStreamReader(ABC):
+    """
+    A base class for simplified streaming data source readers.
+    Compared to :class:`DataSourceStreamReader`, 
:class:`SimpleDataSourceStreamReader` doesn't
+    require planning data partition. Also, the read api of 
:class:`SimpleDataSourceStreamReader`
+    allows reading data and planning the latest offset at the same time.
+
+    Because  :class:`SimpleDataSourceStreamReader` read records in Spark 
driver node to determine
+    end offset of each batch without partitioning, it is only supposed to be 
used in
+    lightweight use cases where input rate and batch size is small.
+    Use :class:`DataSourceStreamReader` when read throughput is high and can't 
be handled
+    by a single process.
+
+    .. versionadded: 4.0.0
+    """
+
+    def initialOffset(self) -> dict:
+        """
+        Return the initial offset of the streaming data source.
+        A new streaming query starts reading data from the initial offset.
+        If Spark is restarting an existing query, it will restart from the 
check-pointed offset
+        rather than the initial one.
+
+        Returns
+        -------
+        dict
+            A dict or recursive dict whose key and value are primitive types, 
which includes
+            Integer, String and Boolean.
+
+        Examples
+        --------
+        >>> def initialOffset(self):
+        ...     return {"parititon-1": {"index": 3, "closed": True}, 
"partition-2": {"index": 5}}
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "initialOffset"},
+        )
+
+    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
+        """
+        Read all available data from start offset and return the offset that 
next read attempt
+        starts from.
+
+        Parameters
+        ----------
+        start : dict
+            The start offset to start reading from.
+
+        Returns
+        -------
+        A :class:`Tuple` of an iterator of :class:`Tuple` and a dict\\s
+            The iterator contains all the available records after start offset.
+            The dict is the end offset of this read attempt and the start of 
next read attempt.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "read"},
+        )
+
+    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
+        """
+        Read all available data from specific start offset and end offset.
+        This is invoked during failure recovery to re-read a batch 
deterministically
+        in order to achieve exactly once.

Review Comment:
   For stateful workloads, this is not only about exactly-once but also about
   correctness. Maybe we could just stop at "deterministic batch execution"
   rather than elaborating further.



##########
python/pyspark/sql/datasource.py:
##########
@@ -469,6 +494,103 @@ def stop(self) -> None:
         ...
 
 
+class SimpleDataSourceStreamReader(ABC):
+    """
+    A base class for simplified streaming data source readers.
+    Compared to :class:`DataSourceStreamReader`, 
:class:`SimpleDataSourceStreamReader` doesn't
+    require planning data partition. Also, the read api of 
:class:`SimpleDataSourceStreamReader`
+    allows reading data and planning the latest offset at the same time.
+
+    Because  :class:`SimpleDataSourceStreamReader` read records in Spark 
driver node to determine
+    end offset of each batch without partitioning, it is only supposed to be 
used in
+    lightweight use cases where input rate and batch size is small.
+    Use :class:`DataSourceStreamReader` when read throughput is high and can't 
be handled
+    by a single process.
+
+    .. versionadded: 4.0.0
+    """
+
+    def initialOffset(self) -> dict:
+        """
+        Return the initial offset of the streaming data source.
+        A new streaming query starts reading data from the initial offset.
+        If Spark is restarting an existing query, it will restart from the 
check-pointed offset
+        rather than the initial one.
+
+        Returns
+        -------
+        dict
+            A dict or recursive dict whose key and value are primitive types, 
which includes
+            Integer, String and Boolean.
+
+        Examples
+        --------
+        >>> def initialOffset(self):
+        ...     return {"parititon-1": {"index": 3, "closed": True}, 
"partition-2": {"index": 5}}
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "initialOffset"},
+        )
+
+    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:

Review Comment:
   I've missed this so far. Since we are close to completion, it'd be awesome
   if we could try to remove points of confusion from the doc, e.g. whether an
   offset is inclusive or exclusive. No need to deal with the doc update in this
   PR; it's probably worth a JIRA ticket.
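
   One convention the doc could state explicitly is "start inclusive, end
   exclusive", so that the end offset returned by one read is exactly the start
   of the next. The sketch below only illustrates that convention; the class
   `CounterReader` and the `index` offset key are hypothetical, and it is plain
   Python that does not import pyspark.

```python
from typing import Iterator, List, Tuple


class CounterReader:
    """Hypothetical reader over an in-memory list (not a pyspark class)."""

    def __init__(self, records: List[str]):
        self.records = list(records)

    def initialOffset(self) -> dict:
        return {"index": 0}

    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
        # start is inclusive: the record at start["index"] is returned.
        begin = start["index"]
        end = len(self.records)
        rows = [(r,) for r in self.records[begin:end]]
        # end is exclusive: it points at the first record not yet returned,
        # so it can be passed unchanged as the start of the next read.
        return iter(rows), {"index": end}

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        # Same half-open convention: [start, end).
        return iter([(r,) for r in self.records[start["index"]:end["index"]]])


reader = CounterReader(["a", "b", "c"])
rows, end = reader.read(reader.initialOffset())
first_batch = list(rows)

reader.records.append("d")  # new data arrives between reads
rows2, end2 = reader.read(end)
second_batch = list(rows2)
```

   With this convention no record is read twice or skipped across consecutive
   reads, and replaying `[start, end)` yields the same rows as the original read.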



##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) -> 
"DataSourceStreamReader":
+    """
+    Fallback to simpleStreamReader() method when streamReader() is not 
implemented.
+    This should be invoked whenever a DataSourceStreamReader needs to be 
created instead of
+    invoking datasource.streamReader() directly.
+    """
+    try:
+        return datasource.streamReader(schema=schema)
+    except PySparkNotImplementedError:
+        return _SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry:
+    def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+        self.start = start
+        self.end = end
+        self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):

Review Comment:
   If we have to separate private from public classes, what about the classes
   above? Do they need to be public? Here what counts as "private" seems very
   unclear. I'm OK if this is a trick to work around a gap in Python's scoping.
   I just want to know whether this is standard practice or not.



##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) -> 
"DataSourceStreamReader":
+    """
+    Fallback to simpleStreamReader() method when streamReader() is not 
implemented.
+    This should be invoked whenever a DataSourceStreamReader needs to be 
created instead of
+    invoking datasource.streamReader() directly.
+    """
+    try:
+        return datasource.streamReader(schema=schema)
+    except PySparkNotImplementedError:
+        return _SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry:
+    def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+        self.start = start
+        self.end = end
+        self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+    """
+    A private class that wrap :class:`SimpleDataSourceStreamReader` in 
prefetch and cache pattern,
+    so that :class:`SimpleDataSourceStreamReader` can integrate with streaming 
engine like an
+    ordinary :class:`DataSourceStreamReader`.
+
+    current_offset tracks the latest progress of the record prefetching, it is 
initialized to be
+    initialOffset() when query start for the first time or initialized to be 
the end offset of
+    the last committed batch when query restarts.
+
+    When streaming engine calls latestOffset(), the wrapper calls read() that 
starts from
+    current_offset, prefetches and cache the data, then updates the 
current_offset to be
+    the end offset of the new data.
+
+    When streaming engine call planInputPartitions(start, end), the wrapper 
get the prefetched data
+    from cache and send it to JVM along with the input partitions.
+
+    When query restart, batches in write ahead offset log that has not been 
committed will be
+    replayed by reading data between start and end offset through read2(start, 
end).
+    """
+
+    def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+        self.simple_reader = simple_reader
+        self.initial_offset: Optional[dict] = None
+        self.current_offset: Optional[dict] = None
+        self.cache: List[PrefetchedCacheEntry] = []
+
+    def initialOffset(self) -> dict:
+        if self.initial_offset is None:
+            self.initial_offset = self.simple_reader.initialOffset()
+        return self.initial_offset
+
+    def latestOffset(self) -> dict:
+        # when query start for the first time, use initial offset as the start 
offset.

Review Comment:
   That said, readBetweenOffsets() must be able to work without the prefetcher:
   PREFETCHED_RECORDS_NOT_FOUND does not only happen in error cases.
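
   A rough sketch of what "work without the prefetcher" could look like: treat a
   cache miss as a normal path and re-read the range from the source. All names
   here (`CacheEntry`, `ListReader`, `rows_for_batch`, `same_offset`) are
   hypothetical, and this is plain Python rather than the code in this PR.

```python
import json
from typing import List, Tuple


class CacheEntry:
    """Hypothetical prefetch cache entry holding already materialized rows."""

    def __init__(self, start: dict, end: dict, rows: List[Tuple]):
        self.start = start
        self.end = end
        self.rows = rows


class ListReader:
    """Hypothetical stand-in for a simple reader over an in-memory list."""

    def __init__(self, records: List[str]):
        self.records = records
        self.replay_calls = 0

    def readBetweenOffsets(self, start: dict, end: dict):
        self.replay_calls += 1
        return iter([(r,) for r in self.records[start["index"]:end["index"]]])


def same_offset(a: dict, b: dict) -> bool:
    # Compare offsets by their canonical JSON form.
    return json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)


def rows_for_batch(cache: List[CacheEntry], reader: ListReader,
                   start: dict, end: dict) -> List[Tuple]:
    for entry in cache:
        if same_offset(entry.start, start) and same_offset(entry.end, end):
            return list(entry.rows)
    # Cache miss is not an error: after a restart the cache is empty, so the
    # planned range is re-read deterministically from the source.
    return list(reader.readBetweenOffsets(start, end))


reader = ListReader(["a", "b", "c", "d"])
# Empty cache, as right after a restart.
replayed = rows_for_batch([], reader, {"index": 1}, {"index": 3})
# Warm cache: the same range is served without touching the source again.
cache = [CacheEntry({"index": 1}, {"index": 3}, [("b",), ("c",)])]
cached = rows_for_batch(cache, reader, {"index": 1}, {"index": 3})
```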



##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) -> 
"DataSourceStreamReader":
+    """
+    Fallback to simpleStreamReader() method when streamReader() is not 
implemented.
+    This should be invoked whenever a DataSourceStreamReader needs to be 
created instead of
+    invoking datasource.streamReader() directly.
+    """
+    try:
+        return datasource.streamReader(schema=schema)
+    except PySparkNotImplementedError:
+        return _SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry:
+    def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+        self.start = start
+        self.end = end
+        self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+    """
+    A private class that wrap :class:`SimpleDataSourceStreamReader` in 
prefetch and cache pattern,
+    so that :class:`SimpleDataSourceStreamReader` can integrate with streaming 
engine like an
+    ordinary :class:`DataSourceStreamReader`.
+
+    current_offset tracks the latest progress of the record prefetching, it is 
initialized to be
+    initialOffset() when query start for the first time or initialized to be 
the end offset of
+    the last committed batch when query restarts.
+
+    When streaming engine calls latestOffset(), the wrapper calls read() that 
starts from
+    current_offset, prefetches and cache the data, then updates the 
current_offset to be
+    the end offset of the new data.
+
+    When streaming engine call planInputPartitions(start, end), the wrapper 
get the prefetched data
+    from cache and send it to JVM along with the input partitions.
+
+    When query restart, batches in write ahead offset log that has not been 
committed will be
+    replayed by reading data between start and end offset through read2(start, 
end).
+    """
+
+    def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+        self.simple_reader = simple_reader
+        self.initial_offset: Optional[dict] = None
+        self.current_offset: Optional[dict] = None
+        self.cache: List[PrefetchedCacheEntry] = []
+
+    def initialOffset(self) -> dict:
+        if self.initial_offset is None:
+            self.initial_offset = self.simple_reader.initialOffset()
+        return self.initial_offset
+
+    def latestOffset(self) -> dict:
+        # when query start for the first time, use initial offset as the start 
offset.

Review Comment:
   If you don't have a test case where a planned batch is in the offset log and
   the query has to restart from there, you need to add one. Run several
   batches, stop the query, leave the last batch not yet committed, and restart
   the query. The prefetcher should not get a request to read from the "initial
   offset", and the read request for the planned batch should work without
   relying on the prefetcher.
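
   The shape of such a test, reduced to a plain-Python simulation. The offset
   log and commit log are modeled as a list and a set, and `RecordingReader` and
   `restart` are hypothetical helpers; a real test would drive an actual
   streaming query against a checkpoint instead.

```python
from typing import List, Set, Tuple


class RecordingReader:
    """Hypothetical reader that records which offsets it was asked for."""

    def __init__(self, records: List[int]):
        self.records = records
        self.read_starts: List[dict] = []
        self.replayed: List[Tuple[dict, dict]] = []

    def initialOffset(self) -> dict:
        return {"index": 0}

    def read(self, start: dict):
        self.read_starts.append(start)
        end = {"index": len(self.records)}
        return iter([(r,) for r in self.records[start["index"]:]]), end

    def readBetweenOffsets(self, start: dict, end: dict):
        self.replayed.append((start, end))
        return iter([(r,) for r in self.records[start["index"]:end["index"]]])


def restart(reader: RecordingReader,
            offset_log: List[Tuple[dict, dict]],
            commit_log: Set[int]):
    """Replay planned-but-uncommitted batches, then resume prefetching."""
    replayed_rows = []
    for batch_id, (start, end) in enumerate(offset_log):
        if batch_id not in commit_log:
            replayed_rows.extend(reader.readBetweenOffsets(start, end))
    # Prefetching resumes from the end of the last planned batch.
    latest_seen = offset_log[-1][1]
    rows, new_end = reader.read(latest_seen)
    return replayed_rows, list(rows), new_end


records = list(range(10))
offset_log = [
    ({"index": 0}, {"index": 3}),
    ({"index": 3}, {"index": 6}),
    ({"index": 6}, {"index": 8}),  # planned, but never committed
]
commit_log = {0, 1}

reader = RecordingReader(records)
replayed, fresh, new_end = restart(reader, offset_log, commit_log)
```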



##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) -> 
"DataSourceStreamReader":
+    """
+    Fallback to simpleStreamReader() method when streamReader() is not 
implemented.
+    This should be invoked whenever a DataSourceStreamReader needs to be 
created instead of
+    invoking datasource.streamReader() directly.
+    """
+    try:
+        return datasource.streamReader(schema=schema)
+    except PySparkNotImplementedError:
+        return _SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry:
+    def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+        self.start = start
+        self.end = end
+        self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+    """
+    A private class that wrap :class:`SimpleDataSourceStreamReader` in 
prefetch and cache pattern,
+    so that :class:`SimpleDataSourceStreamReader` can integrate with streaming 
engine like an
+    ordinary :class:`DataSourceStreamReader`.
+
+    current_offset tracks the latest progress of the record prefetching, it is 
initialized to be
+    initialOffset() when query start for the first time or initialized to be 
the end offset of
+    the last committed batch when query restarts.
+
+    When streaming engine calls latestOffset(), the wrapper calls read() that 
starts from
+    current_offset, prefetches and cache the data, then updates the 
current_offset to be
+    the end offset of the new data.
+
+    When streaming engine call planInputPartitions(start, end), the wrapper 
get the prefetched data
+    from cache and send it to JVM along with the input partitions.
+
+    When query restart, batches in write ahead offset log that has not been 
committed will be
+    replayed by reading data between start and end offset through read2(start, 
end).

Review Comment:
   nit: comment out of sync, read2 is no longer available



##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) -> 
"DataSourceStreamReader":
+    """
+    Fallback to simpleStreamReader() method when streamReader() is not 
implemented.
+    This should be invoked whenever a DataSourceStreamReader needs to be 
created instead of
+    invoking datasource.streamReader() directly.
+    """
+    try:
+        return datasource.streamReader(schema=schema)
+    except PySparkNotImplementedError:
+        return _SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry:
+    def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+        self.start = start
+        self.end = end
+        self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+    """
+    A private class that wrap :class:`SimpleDataSourceStreamReader` in 
prefetch and cache pattern,
+    so that :class:`SimpleDataSourceStreamReader` can integrate with streaming 
engine like an
+    ordinary :class:`DataSourceStreamReader`.
+
+    current_offset tracks the latest progress of the record prefetching, it is 
initialized to be
+    initialOffset() when query start for the first time or initialized to be 
the end offset of
+    the last committed batch when query restarts.
+
+    When streaming engine calls latestOffset(), the wrapper calls read() that 
starts from
+    current_offset, prefetches and cache the data, then updates the 
current_offset to be
+    the end offset of the new data.
+
+    When streaming engine call planInputPartitions(start, end), the wrapper 
get the prefetched data
+    from cache and send it to JVM along with the input partitions.
+
+    When query restart, batches in write ahead offset log that has not been 
committed will be
+    replayed by reading data between start and end offset through read2(start, 
end).
+    """
+
+    def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+        self.simple_reader = simple_reader
+        self.initial_offset: Optional[dict] = None
+        self.current_offset: Optional[dict] = None
+        self.cache: List[PrefetchedCacheEntry] = []
+
+    def initialOffset(self) -> dict:
+        if self.initial_offset is None:
+            self.initial_offset = self.simple_reader.initialOffset()
+        return self.initial_offset
+
+    def latestOffset(self) -> dict:
+        # when query start for the first time, use initial offset as the start 
offset.

Review Comment:
   Actually this is the hard part of implementing a prefetcher for a Structured
   Streaming data source. When the query restarts, we assume that the prefetcher
   would be able to start from a known committed offset. Unfortunately that is
   not true. You've mentioned that this relies on the getBatch trick, but that's
   only applicable to DSv1 and it's clearly a hack to address some specific data
   source.
   
   We have an interface `AcceptsLatestSeenOffset` for this case, but it does not
   give you the last committed offset, only the latest seen offset, so Spark
   could still request an offset range before this offset. It would work, though,
   if the simple data source reader can serve every planned-but-not-yet-committed
   offset range without relying on the prefetcher. The prefetcher can start
   prefetching from the latest seen offset, and the earlier offset ranges should
   be covered by the planned batch(es).
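
   As a sketch of the bookkeeping this implies on the wrapper side: seed
   `current_offset` from the latest seen offset on restart, and fall back to the
   initial offset only for a brand-new query. `PrefetchingWrapper`,
   `set_latest_seen_offset` and `StubReader` are hypothetical names, modeled
   loosely on the idea behind `AcceptsLatestSeenOffset`; this is not the API in
   this PR.

```python
from typing import List, Optional


class StubReader:
    """Hypothetical reader; offsets are {"index": n} over range(size)."""

    def __init__(self, size: int):
        self.size = size
        self.read_starts: List[dict] = []

    def initialOffset(self) -> dict:
        return {"index": 0}

    def read(self, start: dict):
        self.read_starts.append(start)
        return iter(range(start["index"], self.size)), {"index": self.size}


class PrefetchingWrapper:
    """Hypothetical wrapper; only the offset bookkeeping is shown."""

    def __init__(self, reader: StubReader):
        self.reader = reader
        self.current_offset: Optional[dict] = None

    def set_latest_seen_offset(self, offset: dict) -> None:
        # Called once on restart, before the first latestOffset() call.
        self.current_offset = offset

    def latestOffset(self) -> dict:
        if self.current_offset is None:
            # Only a brand-new query starts from the initial offset.
            self.current_offset = self.reader.initialOffset()
        _rows, end = self.reader.read(self.current_offset)
        self.current_offset = end
        return end


fresh = PrefetchingWrapper(StubReader(5))
fresh_end = fresh.latestOffset()

restarted = PrefetchingWrapper(StubReader(5))
restarted.set_latest_seen_offset({"index": 3})
restarted_end = restarted.latestOffset()
```

   Ranges before the latest seen offset are never prefetched in this sketch;
   they would have to be served by re-reading the planned batches directly.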



##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) -> 
"DataSourceStreamReader":
+    """
+    Fallback to simpleStreamReader() method when streamReader() is not 
implemented.
+    This should be invoked whenever a DataSourceStreamReader needs to be 
created instead of
+    invoking datasource.streamReader() directly.
+    """
+    try:
+        return datasource.streamReader(schema=schema)
+    except PySparkNotImplementedError:
+        return _SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):

Review Comment:
   What's the best practice or code style here? Do we have to document every
   public class, or only the public/developer API? From what I understand this is
   not even a developer API (third parties are not required to implement it),
   and the same goes for PrefetchedCacheEntry.



##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) -> 
"DataSourceStreamReader":
+    """
+    Fallback to simpleStreamReader() method when streamReader() is not 
implemented.
+    This should be invoked whenever a DataSourceStreamReader needs to be 
created instead of
+    invoking datasource.streamReader() directly.
+    """
+    try:
+        return datasource.streamReader(schema=schema)
+    except PySparkNotImplementedError:
+        return _SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry:
+    def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+        self.start = start
+        self.end = end
+        self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+    """
+    A private class that wrap :class:`SimpleDataSourceStreamReader` in 
prefetch and cache pattern,
+    so that :class:`SimpleDataSourceStreamReader` can integrate with streaming 
engine like an
+    ordinary :class:`DataSourceStreamReader`.
+
+    current_offset tracks the latest progress of the record prefetching, it is 
initialized to be
+    initialOffset() when query start for the first time or initialized to be 
the end offset of
+    the last committed batch when query restarts.
+
+    When streaming engine calls latestOffset(), the wrapper calls read() that 
starts from
+    current_offset, prefetches and cache the data, then updates the 
current_offset to be
+    the end offset of the new data.
+
+    When streaming engine call planInputPartitions(start, end), the wrapper 
get the prefetched data
+    from cache and send it to JVM along with the input partitions.
+
+    When query restart, batches in write ahead offset log that has not been 
committed will be
+    replayed by reading data between start and end offset through read2(start, 
end).
+    """
+
+    def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+        self.simple_reader = simple_reader
+        self.initial_offset: Optional[dict] = None
+        self.current_offset: Optional[dict] = None
+        self.cache: List[PrefetchedCacheEntry] = []
+
+    def initialOffset(self) -> dict:
+        if self.initial_offset is None:
+            self.initial_offset = self.simple_reader.initialOffset()
+        return self.initial_offset
+
+    def latestOffset(self) -> dict:
+        # when query start for the first time, use initial offset as the start 
offset.
+        if self.current_offset is None:
+            self.current_offset = self.initialOffset()
+        (iter, end) = self.simple_reader.read(self.current_offset)
+        self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
+        self.current_offset = end
+        return end
+
+    def commit(self, end: dict) -> None:
+        if self.current_offset is None:
+            self.current_offset = end
+
+        end_idx = -1
+        for idx, entry in enumerate(self.cache):
+            if json.dumps(entry.end) == json.dumps(end):
+                end_idx = idx
+                break
+        if end_idx > 0:
+            # Drop prefetched data for batch that has been committed.
+            self.cache = self.cache[end_idx:]
+        self.simple_reader.commit(end)
+
+    def partitions(self, start: dict, end: dict) -> Sequence["InputPartition"]:
+        # when query restart from checkpoint, use the last committed offset as 
the start offset.
+        # This depends on the current behavior that streaming engine call 
getBatch on the last

Review Comment:
   This trick only applies to DSv1 AFAIK.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

