This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dff52d649d1 [SPARK-37015][PYTHON] Inline type hints for python/pyspark/streaming/dstream.py
dff52d649d1 is described below

commit dff52d649d1e27baf3b107f75636624e0cfe780f
Author: dch nguyen <dchvn....@gmail.com>
AuthorDate: Mon Apr 18 17:38:32 2022 +0200

    [SPARK-37015][PYTHON] Inline type hints for python/pyspark/streaming/dstream.py
    
    ### What changes were proposed in this pull request?
    Inline type hints for python/pyspark/streaming/dstream.py
    
    ### Why are the changes needed?
    We can take advantage of static type checking within the functions by inlining the type hints.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests
    
    Closes #34324 from dchvn/SPARK-37015.
    
    Lead-authored-by: dch nguyen <dchvn....@gmail.com>
    Co-authored-by: dch nguyen <dgd_contribu...@viettel.com.vn>
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
---
 python/pyspark/streaming/context.py  |  22 +--
 python/pyspark/streaming/dstream.py  | 369 +++++++++++++++++++++++++++--------
 python/pyspark/streaming/dstream.pyi | 211 --------------------
 3 files changed, 296 insertions(+), 306 deletions(-)

diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 52e5efed063..0be0c7b034a 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -397,12 +397,12 @@ class StreamingContext:
         the transform function parameter will be the same as the order
         of corresponding DStreams in the list.
         """
-        jdstreams = [d._jdstream for d in dstreams]  # type: 
ignore[attr-defined]
+        jdstreams = [d._jdstream for d in dstreams]
         # change the final serializer to sc.serializer
         func = TransformFunction(
             self._sc,
             lambda t, *rdds: transformFunc(rdds),
-            *[d._jrdd_deserializer for d in dstreams],  # type: 
ignore[attr-defined]
+            *[d._jrdd_deserializer for d in dstreams],
         )
 
         assert self._jvm is not None
@@ -419,35 +419,31 @@ class StreamingContext:
             raise ValueError("should have at least one DStream to union")
         if len(dstreams) == 1:
             return dstreams[0]
-        if len(set(s._jrdd_deserializer for s in dstreams)) > 1:  # type: 
ignore[attr-defined]
+        if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
             raise ValueError("All DStreams should have same serializer")
-        if len(set(s._slideDuration for s in dstreams)) > 1:  # type: 
ignore[attr-defined]
+        if len(set(s._slideDuration for s in dstreams)) > 1:
             raise ValueError("All DStreams should have same slide duration")
 
         assert SparkContext._jvm is not None
         jdstream_cls = 
SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
         jpair_dstream_cls = 
SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
         gw = SparkContext._gateway
-        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):  # type: 
ignore[attr-defined]
+        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
             cls = jdstream_cls
-        elif is_instance_of(
-            gw, dstreams[0]._jdstream, jpair_dstream_cls  # type: 
ignore[attr-defined]
-        ):
+        elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):
             cls = jpair_dstream_cls
         else:
-            cls_name = (
-                dstreams[0]._jdstream.getClass().getCanonicalName()  # type: 
ignore[attr-defined]
-            )
+            cls_name = dstreams[0]._jdstream.getClass().getCanonicalName()
             raise TypeError("Unsupported Java DStream class %s" % cls_name)
 
         assert gw is not None
         jdstreams = gw.new_array(cls, len(dstreams))
         for i in range(0, len(dstreams)):
-            jdstreams[i] = dstreams[i]._jdstream  # type: ignore[attr-defined]
+            jdstreams[i] = dstreams[i]._jdstream
         return DStream(
             self._jssc.union(jdstreams),
             self,
-            dstreams[0]._jrdd_deserializer,  # type: ignore[attr-defined]
+            dstreams[0]._jrdd_deserializer,
         )
 
     def addStreamingListener(self, streamingListener: StreamingListener) -> 
None:
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index f445a78bd95..934b3ae5783 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -19,19 +19,45 @@ import operator
 import time
 from itertools import chain
 from datetime import datetime
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    TYPE_CHECKING,
+    cast,
+    overload,
+)
 
 from py4j.protocol import Py4JJavaError
 
-from pyspark import RDD
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.util import rddToFileName, TransformFunction
-from pyspark.rdd import portable_hash
+from pyspark.rdd import portable_hash, RDD
 from pyspark.resultiterable import ResultIterable
+from py4j.java_gateway import JavaObject
+
+if TYPE_CHECKING:
+    from pyspark.serializers import Serializer
+    from pyspark.streaming.context import StreamingContext
 
 __all__ = ["DStream"]
 
+S = TypeVar("S")
+T = TypeVar("T")
+T_co = TypeVar("T_co", covariant=True)
+U = TypeVar("U")
+K = TypeVar("K", bound=Hashable)
+V = TypeVar("V")
+
 
-class DStream:
+class DStream(Generic[T_co]):
     """
     A Discretized Stream (DStream), the basic abstraction in Spark Streaming,
     is a continuous sequence of RDDs (of the same type) representing a
@@ -51,7 +77,12 @@ class DStream:
      - A function that is used to generate an RDD after each time interval
     """
 
-    def __init__(self, jdstream, ssc, jrdd_deserializer):
+    def __init__(
+        self,
+        jdstream: JavaObject,
+        ssc: "StreamingContext",
+        jrdd_deserializer: "Serializer",
+    ):
         self._jdstream = jdstream
         self._ssc = ssc
         self._sc = ssc._sc
@@ -59,76 +90,94 @@ class DStream:
         self.is_cached = False
         self.is_checkpointed = False
 
-    def context(self):
+    def context(self) -> "StreamingContext":
         """
         Return the StreamingContext associated with this DStream
         """
         return self._ssc
 
-    def count(self):
+    def count(self) -> "DStream[int]":
         """
         Return a new DStream in which each RDD has a single element
         generated by counting each RDD of this DStream.
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in 
i)]).reduce(operator.add)
 
-    def filter(self, f):
+    def filter(self: "DStream[T]", f: Callable[[T], bool]) -> "DStream[T]":
         """
         Return a new DStream containing only the elements that satisfy 
predicate.
         """
 
-        def func(iterator):
+        def func(iterator: Iterable[T]) -> Iterable[T]:
             return filter(f, iterator)
 
         return self.mapPartitions(func, True)
 
-    def flatMap(self, f, preservesPartitioning=False):
+    def flatMap(
+        self: "DStream[T]",
+        f: Callable[[T], Iterable[U]],
+        preservesPartitioning: bool = False,
+    ) -> "DStream[U]":
         """
         Return a new DStream by applying a function to all elements of
         this DStream, and then flattening the results
         """
 
-        def func(s, iterator):
+        def func(s: int, iterator: Iterable[T]) -> Iterable[U]:
             return chain.from_iterable(map(f, iterator))
 
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
-    def map(self, f, preservesPartitioning=False):
+    def map(
+        self: "DStream[T]", f: Callable[[T], U], preservesPartitioning: bool = 
False
+    ) -> "DStream[U]":
         """
         Return a new DStream by applying a function to each element of DStream.
         """
 
-        def func(iterator):
+        def func(iterator: Iterable[T]) -> Iterable[U]:
             return map(f, iterator)
 
         return self.mapPartitions(func, preservesPartitioning)
 
-    def mapPartitions(self, f, preservesPartitioning=False):
+    def mapPartitions(
+        self: "DStream[T]",
+        f: Callable[[Iterable[T]], Iterable[U]],
+        preservesPartitioning: bool = False,
+    ) -> "DStream[U]":
         """
         Return a new DStream in which each RDD is generated by applying
         mapPartitions() to each RDDs of this DStream.
         """
 
-        def func(s, iterator):
+        def func(s: int, iterator: Iterable[T]) -> Iterable[U]:
             return f(iterator)
 
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
-    def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+    def mapPartitionsWithIndex(
+        self: "DStream[T]",
+        f: Callable[[int, Iterable[T]], Iterable[U]],
+        preservesPartitioning: bool = False,
+    ) -> "DStream[U]":
         """
         Return a new DStream in which each RDD is generated by applying
         mapPartitionsWithIndex() to each RDDs of this DStream.
         """
         return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, 
preservesPartitioning))
 
-    def reduce(self, func):
+    def reduce(self: "DStream[T]", func: Callable[[T, T], T]) -> "DStream[T]":
         """
         Return a new DStream in which each RDD has a single element
         generated by reducing each RDD of this DStream.
         """
         return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda 
x: x[1])
 
-    def reduceByKey(self, func, numPartitions=None):
+    def reduceByKey(
+        self: "DStream[Tuple[K, V]]",
+        func: Callable[[V, V], V],
+        numPartitions: Optional[int] = None,
+    ) -> "DStream[Tuple[K, V]]":
         """
         Return a new DStream by applying reduceByKey to each RDD.
         """
@@ -136,40 +185,62 @@ class DStream:
             numPartitions = self._sc.defaultParallelism
         return self.combineByKey(lambda x: x, func, func, numPartitions)
 
-    def combineByKey(self, createCombiner, mergeValue, mergeCombiners, 
numPartitions=None):
+    def combineByKey(
+        self: "DStream[Tuple[K, V]]",
+        createCombiner: Callable[[V], U],
+        mergeValue: Callable[[U, V], U],
+        mergeCombiners: Callable[[U, U], U],
+        numPartitions: Optional[int] = None,
+    ) -> "DStream[Tuple[K, U]]":
         """
         Return a new DStream by applying combineByKey to each RDD.
         """
         if numPartitions is None:
             numPartitions = self._sc.defaultParallelism
 
-        def func(rdd):
+        def func(rdd: RDD[Tuple[K, V]]) -> RDD[Tuple[K, U]]:
             return rdd.combineByKey(createCombiner, mergeValue, 
mergeCombiners, numPartitions)
 
         return self.transform(func)
 
-    def partitionBy(self, numPartitions, partitionFunc=portable_hash):
+    def partitionBy(
+        self: "DStream[Tuple[K, V]]",
+        numPartitions: int,
+        partitionFunc: Callable[[K], int] = portable_hash,
+    ) -> "DStream[Tuple[K, V]]":
         """
         Return a copy of the DStream in which each RDD are partitioned
         using the specified partitioner.
         """
         return self.transform(lambda rdd: rdd.partitionBy(numPartitions, 
partitionFunc))
 
-    def foreachRDD(self, func):
+    @overload
+    def foreachRDD(self: "DStream[T]", func: Callable[[RDD[T]], None]) -> None:
+        ...
+
+    @overload
+    def foreachRDD(self: "DStream[T]", func: Callable[[datetime, RDD[T]], 
None]) -> None:
+        ...
+
+    def foreachRDD(
+        self: "DStream[T]",
+        func: Union[Callable[[RDD[T]], None], Callable[[datetime, RDD[T]], 
None]],
+    ) -> None:
         """
         Apply a function to each RDD in this DStream.
         """
         if func.__code__.co_argcount == 1:
             old_func = func
 
-            def func(_, rdd):
-                return old_func(rdd)
+            def func(_: datetime, rdd: "RDD[T]") -> None:
+                return old_func(rdd)  # type: ignore[call-arg, arg-type]
 
         jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
+        assert self._ssc._jvm is not None
         api = self._ssc._jvm.PythonDStream
         api.callForeachRDD(self._jdstream, jfunc)
 
-    def pprint(self, num=10):
+    def pprint(self, num: int = 10) -> None:
         """
         Print the first num elements of each RDD generated in this DStream.
 
@@ -179,7 +250,7 @@ class DStream:
             the number of elements from the first will be printed.
         """
 
-        def takeAndPrint(time, rdd):
+        def takeAndPrint(time: datetime, rdd: RDD[T]) -> None:
             taken = rdd.take(num + 1)
             print("-------------------------------------------")
             print("Time: %s" % time)
@@ -192,40 +263,42 @@ class DStream:
 
         self.foreachRDD(takeAndPrint)
 
-    def mapValues(self, f):
+    def mapValues(self: "DStream[Tuple[K, V]]", f: Callable[[V], U]) -> 
"DStream[Tuple[K, U]]":
         """
         Return a new DStream by applying a map function to the value of
         each key-value pairs in this DStream without changing the key.
         """
 
-        def map_values_fn(kv):
+        def map_values_fn(kv: Tuple[K, V]) -> Tuple[K, U]:
             return kv[0], f(kv[1])
 
         return self.map(map_values_fn, preservesPartitioning=True)
 
-    def flatMapValues(self, f):
+    def flatMapValues(
+        self: "DStream[Tuple[K, V]]", f: Callable[[V], Iterable[U]]
+    ) -> "DStream[Tuple[K, U]]":
         """
         Return a new DStream by applying a flatmap function to the value
         of each key-value pairs in this DStream without changing the key.
         """
 
-        def flat_map_fn(kv):
+        def flat_map_fn(kv: Tuple[K, V]) -> Iterable[Tuple[K, U]]:
             return ((kv[0], x) for x in f(kv[1]))
 
         return self.flatMap(flat_map_fn, preservesPartitioning=True)
 
-    def glom(self):
+    def glom(self: "DStream[T]") -> "DStream[List[T]]":
         """
         Return a new DStream in which RDD is generated by applying glom()
         to RDD of this DStream.
         """
 
-        def func(iterator):
+        def func(iterator: Iterable[T]) -> Iterable[List[T]]:
             yield list(iterator)
 
         return self.mapPartitions(func)
 
-    def cache(self):
+    def cache(self: "DStream[T]") -> "DStream[T]":
         """
         Persist the RDDs of this DStream with the default storage level
         (`MEMORY_ONLY`).
@@ -234,7 +307,7 @@ class DStream:
         self.persist(StorageLevel.MEMORY_ONLY)
         return self
 
-    def persist(self, storageLevel):
+    def persist(self: "DStream[T]", storageLevel: StorageLevel) -> 
"DStream[T]":
         """
         Persist the RDDs of this DStream with the given storage level
         """
@@ -243,7 +316,7 @@ class DStream:
         self._jdstream.persist(javaStorageLevel)
         return self
 
-    def checkpoint(self, interval):
+    def checkpoint(self: "DStream[T]", interval: int) -> "DStream[T]":
         """
         Enable periodic checkpointing of RDDs of this DStream
 
@@ -257,7 +330,9 @@ class DStream:
         self._jdstream.checkpoint(self._ssc._jduration(interval))
         return self
 
-    def groupByKey(self, numPartitions=None):
+    def groupByKey(
+        self: "DStream[Tuple[K, V]]", numPartitions: Optional[int] = None
+    ) -> "DStream[Tuple[K, Iterable[V]]]":
         """
         Return a new DStream by applying groupByKey on each RDD.
         """
@@ -265,20 +340,20 @@ class DStream:
             numPartitions = self._sc.defaultParallelism
         return self.transform(lambda rdd: rdd.groupByKey(numPartitions))
 
-    def countByValue(self):
+    def countByValue(self: "DStream[K]") -> "DStream[Tuple[K, int]]":
         """
         Return a new DStream in which each RDD contains the counts of each
         distinct value in each RDD of this DStream.
         """
         return self.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
 
-    def saveAsTextFiles(self, prefix, suffix=None):
+    def saveAsTextFiles(self, prefix: str, suffix: Optional[str] = None) -> 
None:
         """
         Save each RDD in this DStream as at text file, using string
         representation of elements.
         """
 
-        def saveAsTextFile(t, rdd):
+        def saveAsTextFile(t: Optional[datetime], rdd: RDD[T]) -> None:
             path = rddToFileName(prefix, suffix, t)
             try:
                 rdd.saveAsTextFile(path)
@@ -307,7 +382,20 @@ class DStream:
     #                 raise
     #     return self.foreachRDD(saveAsPickleFile)
 
-    def transform(self, func):
+    @overload
+    def transform(self: "DStream[T]", func: Callable[[RDD[T]], RDD[U]]) -> 
"TransformedDStream[U]":
+        ...
+
+    @overload
+    def transform(
+        self: "DStream[T]", func: Callable[[datetime, RDD[T]], RDD[U]]
+    ) -> "TransformedDStream[U]":
+        ...
+
+    def transform(
+        self: "DStream[T]",
+        func: Union[Callable[[RDD[T]], RDD[U]], Callable[[datetime, RDD[T]], 
RDD[U]]],
+    ) -> "TransformedDStream[U]":
         """
         Return a new DStream in which each RDD is generated by applying a 
function
         on each RDD of this DStream.
@@ -318,13 +406,39 @@ class DStream:
         if func.__code__.co_argcount == 1:
             oldfunc = func
 
-            def func(_, rdd):
-                return oldfunc(rdd)
+            def func(_: datetime, rdd: RDD[T]) -> RDD[U]:
+                return oldfunc(rdd)  # type: ignore[arg-type, call-arg]
 
         assert func.__code__.co_argcount == 2, "func should take one or two 
arguments"
         return TransformedDStream(self, func)
 
-    def transformWith(self, func, other, keepSerializer=False):
+    @overload
+    def transformWith(
+        self: "DStream[T]",
+        func: Callable[[RDD[T], RDD[U]], RDD[V]],
+        other: "DStream[U]",
+        keepSerializer: bool = ...,
+    ) -> "DStream[V]":
+        ...
+
+    @overload
+    def transformWith(
+        self: "DStream[T]",
+        func: Callable[[datetime, RDD[T], RDD[U]], RDD[V]],
+        other: "DStream[U]",
+        keepSerializer: bool = ...,
+    ) -> "DStream[V]":
+        ...
+
+    def transformWith(
+        self: "DStream[T]",
+        func: Union[
+            Callable[[RDD[T], RDD[U]], RDD[V]],
+            Callable[[datetime, RDD[T], RDD[U]], RDD[V]],
+        ],
+        other: "DStream[U]",
+        keepSerializer: bool = False,
+    ) -> "DStream[V]":
         """
         Return a new DStream in which each RDD is generated by applying a 
function
         on each RDD of this DStream and 'other' DStream.
@@ -335,31 +449,37 @@ class DStream:
         if func.__code__.co_argcount == 2:
             oldfunc = func
 
-            def func(_, a, b):
-                return oldfunc(a, b)
+            def func(_: datetime, a: RDD[T], b: RDD[U]) -> RDD[V]:
+                return oldfunc(a, b)  # type: ignore[call-arg, arg-type]
 
         assert func.__code__.co_argcount == 3, "func should take two or three 
arguments"
-        jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, 
other._jrdd_deserializer)
+        jfunc = TransformFunction(
+            self._sc,
+            func,
+            self._jrdd_deserializer,
+            other._jrdd_deserializer,
+        )
+        assert self._sc._jvm is not None
         dstream = self._sc._jvm.PythonTransformed2DStream(
             self._jdstream.dstream(), other._jdstream.dstream(), jfunc
         )
         jrdd_serializer = self._jrdd_deserializer if keepSerializer else 
self._sc.serializer
         return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer)
 
-    def repartition(self, numPartitions):
+    def repartition(self: "DStream[T]", numPartitions: int) -> "DStream[T]":
         """
         Return a new DStream with an increased or decreased level of 
parallelism.
         """
         return self.transform(lambda rdd: rdd.repartition(numPartitions))
 
     @property
-    def _slideDuration(self):
+    def _slideDuration(self) -> None:
         """
         Return the slideDuration in seconds of this DStream
         """
         return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0
 
-    def union(self, other):
+    def union(self: "DStream[T]", other: "DStream[U]") -> "DStream[Union[T, 
U]]":
         """
         Return a new DStream by unifying data of another DStream with this 
DStream.
 
@@ -373,7 +493,11 @@ class DStream:
             raise ValueError("the two DStream should have same slide duration")
         return self.transformWith(lambda a, b: a.union(b), other, True)
 
-    def cogroup(self, other, numPartitions=None):
+    def cogroup(
+        self: "DStream[Tuple[K, V]]",
+        other: "DStream[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "DStream[Tuple[K, Tuple[ResultIterable[V], ResultIterable[U]]]]":
         """
         Return a new DStream by applying 'cogroup' between RDDs of this
         DStream and `other` DStream.
@@ -382,9 +506,16 @@ class DStream:
         """
         if numPartitions is None:
             numPartitions = self._sc.defaultParallelism
-        return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), 
other)
+        return self.transformWith(
+            lambda a, b: a.cogroup(b, numPartitions),
+            other,
+        )
 
-    def join(self, other, numPartitions=None):
+    def join(
+        self: "DStream[Tuple[K, V]]",
+        other: "DStream[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "DStream[Tuple[K, Tuple[V, U]]]":
         """
         Return a new DStream by applying 'join' between RDDs of this DStream 
and
         `other` DStream.
@@ -396,7 +527,11 @@ class DStream:
             numPartitions = self._sc.defaultParallelism
         return self.transformWith(lambda a, b: a.join(b, numPartitions), other)
 
-    def leftOuterJoin(self, other, numPartitions=None):
+    def leftOuterJoin(
+        self: "DStream[Tuple[K, V]]",
+        other: "DStream[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "DStream[Tuple[K, Tuple[V, Optional[U]]]]":
         """
         Return a new DStream by applying 'left outer join' between RDDs of 
this DStream and
         `other` DStream.
@@ -408,7 +543,11 @@ class DStream:
             numPartitions = self._sc.defaultParallelism
         return self.transformWith(lambda a, b: a.leftOuterJoin(b, 
numPartitions), other)
 
-    def rightOuterJoin(self, other, numPartitions=None):
+    def rightOuterJoin(
+        self: "DStream[Tuple[K, V]]",
+        other: "DStream[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "DStream[Tuple[K, Tuple[Optional[V], U]]]":
         """
         Return a new DStream by applying 'right outer join' between RDDs of 
this DStream and
         `other` DStream.
@@ -420,7 +559,11 @@ class DStream:
             numPartitions = self._sc.defaultParallelism
         return self.transformWith(lambda a, b: a.rightOuterJoin(b, 
numPartitions), other)
 
-    def fullOuterJoin(self, other, numPartitions=None):
+    def fullOuterJoin(
+        self: "DStream[Tuple[K, V]]",
+        other: "DStream[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "DStream[Tuple[K, Tuple[Optional[V], Optional[U]]]]":
         """
         Return a new DStream by applying 'full outer join' between RDDs of 
this DStream and
         `other` DStream.
@@ -432,13 +575,14 @@ class DStream:
             numPartitions = self._sc.defaultParallelism
         return self.transformWith(lambda a, b: a.fullOuterJoin(b, 
numPartitions), other)
 
-    def _jtime(self, timestamp):
+    def _jtime(self, timestamp: Union[datetime, int, float]) -> JavaObject:
         """Convert datetime or unix_timestamp into Time"""
         if isinstance(timestamp, datetime):
             timestamp = time.mktime(timestamp.timetuple())
+        assert self._sc._jvm is not None
         return self._sc._jvm.Time(int(timestamp * 1000))
 
-    def slice(self, begin, end):
+    def slice(self, begin: Union[datetime, int], end: Union[datetime, int]) -> 
List[RDD[T]]:
         """
         Return all the RDDs between 'begin' to 'end' (both included)
 
@@ -447,7 +591,7 @@ class DStream:
         jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
         return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds]
 
-    def _validate_window_param(self, window, slide):
+    def _validate_window_param(self, window: int, slide: Optional[int]) -> 
None:
         duration = self._jdstream.dstream().slideDuration().milliseconds()
         if int(window * 1000) % duration != 0:
             raise ValueError(
@@ -460,7 +604,7 @@ class DStream:
                 "dstream's slide (batch) duration (%d ms)" % duration
             )
 
-    def window(self, windowDuration, slideDuration=None):
+    def window(self, windowDuration: int, slideDuration: Optional[int] = None) 
-> "DStream[T]":
         """
         Return a new DStream in which each RDD contains all the elements in 
seen in a
         sliding window of time over this DStream.
@@ -482,7 +626,13 @@ class DStream:
         s = self._ssc._jduration(slideDuration)
         return DStream(self._jdstream.window(d, s), self._ssc, 
self._jrdd_deserializer)
 
-    def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, 
slideDuration):
+    def reduceByWindow(
+        self: "DStream[T]",
+        reduceFunc: Callable[[T, T], T],
+        invReduceFunc: Optional[Callable[[T, T], T]],
+        windowDuration: int,
+        slideDuration: int,
+    ) -> "DStream[T]":
         """
         Return a new DStream in which each RDD has a single element generated 
by reducing all
         elements in a sliding window over this DStream.
@@ -517,7 +667,9 @@ class DStream:
         )
         return reduced.map(lambda kv: kv[1])
 
-    def countByWindow(self, windowDuration, slideDuration):
+    def countByWindow(
+        self: "DStream[T]", windowDuration: int, slideDuration: int
+    ) -> "DStream[int]":
         """
         Return a new DStream in which each RDD has a single element generated
         by counting the number of elements in a window over this DStream.
@@ -530,7 +682,12 @@ class DStream:
             operator.add, operator.sub, windowDuration, slideDuration
         )
 
-    def countByValueAndWindow(self, windowDuration, slideDuration, 
numPartitions=None):
+    def countByValueAndWindow(
+        self: "DStream[T]",
+        windowDuration: int,
+        slideDuration: int,
+        numPartitions: Optional[int] = None,
+    ) -> "DStream[Tuple[T, int]]":
         """
         Return a new DStream in which each RDD contains the count of distinct 
elements in
         RDDs in a sliding window over this DStream.
@@ -553,7 +710,12 @@ class DStream:
         )
         return counted.filter(lambda kv: kv[1] > 0)
 
-    def groupByKeyAndWindow(self, windowDuration, slideDuration, 
numPartitions=None):
+    def groupByKeyAndWindow(
+        self: "DStream[Tuple[K, V]]",
+        windowDuration: int,
+        slideDuration: int,
+        numPartitions: Optional[int] = None,
+    ) -> "DStream[Tuple[K, Iterable[V]]]":
         """
         Return a new DStream by applying `groupByKey` over a sliding window.
         Similar to `DStream.groupByKey()`, but applies it over a sliding 
window.
@@ -572,7 +734,7 @@ class DStream:
         """
         ls = self.mapValues(lambda x: [x])
         grouped = ls.reduceByKeyAndWindow(
-            lambda a, b: a.extend(b) or a,
+            lambda a, b: a.extend(b) or a,  # type: ignore[func-returns-value]
             lambda a, b: a[len(b) :],
             windowDuration,
             slideDuration,
@@ -581,8 +743,14 @@ class DStream:
         return grouped.mapValues(ResultIterable)
 
     def reduceByKeyAndWindow(
-        self, func, invFunc, windowDuration, slideDuration=None, 
numPartitions=None, filterFunc=None
-    ):
+        self: "DStream[Tuple[K, V]]",
+        func: Callable[[V, V], V],
+        invFunc: Optional[Callable[[V, V], V]],
+        windowDuration: int,
+        slideDuration: Optional[int] = None,
+        numPartitions: Optional[int] = None,
+        filterFunc: Optional[Callable[[Tuple[K, V]], bool]] = None,
+    ) -> "DStream[Tuple[K, V]]":
         """
         Return a new DStream by applying incremental `reduceByKey` over a 
sliding window.
 
@@ -621,36 +789,46 @@ class DStream:
 
         if invFunc:
 
-            def reduceFunc(t, a, b):
+            def reduceFunc(t: datetime, a: Any, b: Any) -> Any:
                 b = b.reduceByKey(func, numPartitions)
                 r = a.union(b).reduceByKey(func, numPartitions) if a else b
                 if filterFunc:
                     r = r.filter(filterFunc)
                 return r
 
-            def invReduceFunc(t, a, b):
+            def invReduceFunc(t: datetime, a: Any, b: Any) -> Any:
                 b = b.reduceByKey(func, numPartitions)
                 joined = a.leftOuterJoin(b, numPartitions)
                 return joined.mapValues(
-                    lambda kv: invFunc(kv[0], kv[1]) if kv[1] is not None else 
kv[0]
+                    lambda kv: invFunc(kv[0], kv[1])  # type: ignore[misc]
+                    if kv[1] is not None
+                    else kv[0]
                 )
 
             jreduceFunc = TransformFunction(self._sc, reduceFunc, 
reduced._jrdd_deserializer)
             jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, 
reduced._jrdd_deserializer)
             if slideDuration is None:
                 slideDuration = self._slideDuration
+            assert self._sc._jvm is not None
             dstream = self._sc._jvm.PythonReducedWindowedDStream(
                 reduced._jdstream.dstream(),
                 jreduceFunc,
                 jinvReduceFunc,
                 self._ssc._jduration(windowDuration),
-                self._ssc._jduration(slideDuration),
+                self._ssc._jduration(slideDuration),  # type: ignore[arg-type]
             )
             return DStream(dstream.asJavaDStream(), self._ssc, 
self._sc.serializer)
         else:
-            return reduced.window(windowDuration, 
slideDuration).reduceByKey(func, numPartitions)
+            return reduced.window(windowDuration, slideDuration).reduceByKey(
+                func, numPartitions  # type: ignore[arg-type]
+            )
 
-    def updateStateByKey(self, updateFunc, numPartitions=None, 
initialRDD=None):
+    def updateStateByKey(
+        self: "DStream[Tuple[K, V]]",
+        updateFunc: Callable[[Iterable[V], Optional[S]], S],
+        numPartitions: Optional[int] = None,
+        initialRDD: Optional[Union[RDD[Tuple[K, S]], Iterable[Tuple[K, S]]]] = 
None,
+    ) -> "DStream[Tuple[K, S]]":
         """
         Return a new "state" DStream where the state for each key is updated 
by applying
         the given function on the previous state of the key and the new values 
of the key.
@@ -667,30 +845,37 @@ class DStream:
         if initialRDD and not isinstance(initialRDD, RDD):
             initialRDD = self._sc.parallelize(initialRDD)
 
-        def reduceFunc(t, a, b):
+        def reduceFunc(t: datetime, a: Any, b: Any) -> Any:
             if a is None:
                 g = b.groupByKey(numPartitions).mapValues(lambda vs: 
(list(vs), None))
             else:
-                g = a.cogroup(b.partitionBy(numPartitions), numPartitions)
+                g = a.cogroup(b.partitionBy(cast(int, numPartitions)), 
numPartitions)
                 g = g.mapValues(lambda ab: (list(ab[1]), list(ab[0])[0] if 
len(ab[0]) else None))
             state = g.mapValues(lambda vs_s: updateFunc(vs_s[0], vs_s[1]))
             return state.filter(lambda k_v: k_v[1] is not None)
 
         jreduceFunc = TransformFunction(
-            self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer
+            self._sc,
+            reduceFunc,
+            self._sc.serializer,
+            self._jrdd_deserializer,
         )
         if initialRDD:
-            initialRDD = initialRDD._reserialize(self._jrdd_deserializer)
+            initialRDD = cast(RDD[Tuple[K, S]], 
initialRDD)._reserialize(self._jrdd_deserializer)
+            assert self._sc._jvm is not None
             dstream = self._sc._jvm.PythonStateDStream(
-                self._jdstream.dstream(), jreduceFunc, initialRDD._jrdd
+                self._jdstream.dstream(),
+                jreduceFunc,
+                initialRDD._jrdd,
             )
         else:
+            assert self._sc._jvm is not None
             dstream = 
self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
 
         return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
 
 
-class TransformedDStream(DStream):
+class TransformedDStream(DStream[U]):
     """
     TransformedDStream is a DStream generated by an Python function
     transforming each RDD of a DStream to another RDDs.
@@ -699,7 +884,23 @@ class TransformedDStream(DStream):
     one transformation.
     """
 
-    def __init__(self, prev, func):
+    @overload
+    def __init__(self: DStream[U], prev: DStream[T], func: Callable[[RDD[T]], 
RDD[U]]):
+        ...
+
+    @overload
+    def __init__(
+        self: DStream[U],
+        prev: DStream[T],
+        func: Callable[[datetime, RDD[T]], RDD[U]],
+    ):
+        ...
+
+    def __init__(
+        self,
+        prev: DStream[T],
+        func: Union[Callable[[RDD[T]], RDD[U]], Callable[[datetime, RDD[T]], 
RDD[U]]],
+    ):
         self._ssc = prev._ssc
         self._sc = self._ssc._sc
         self._jrdd_deserializer = self._sc.serializer
@@ -710,19 +911,23 @@ class TransformedDStream(DStream):
         # Using type() to avoid folding the functions and compacting the 
DStreams which is not
         # not strictly an object of TransformedDStream.
         if type(prev) is TransformedDStream and not prev.is_cached and not 
prev.is_checkpointed:
-            prev_func = prev.func
-            self.func = lambda t, rdd: func(t, prev_func(t, rdd))
-            self.prev = prev.prev
+            prev_func: Callable = prev.func
+            func = cast(Callable[[datetime, RDD[T]], RDD[U]], func)
+            self.func: Union[
+                Callable[[RDD[T]], RDD[U]], Callable[[datetime, RDD[T]], 
RDD[U]]
+            ] = lambda t, rdd: func(t, prev_func(t, rdd))
+            self.prev: DStream[T] = prev.prev
         else:
             self.prev = prev
             self.func = func
 
     @property
-    def _jdstream(self):
+    def _jdstream(self) -> JavaObject:
         if self._jdstream_val is not None:
             return self._jdstream_val
 
         jfunc = TransformFunction(self._sc, self.func, 
self.prev._jrdd_deserializer)
+        assert self._sc._jvm is not None
         dstream = 
self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
         self._jdstream_val = dstream.asJavaDStream()
         return self._jdstream_val
diff --git a/python/pyspark/streaming/dstream.pyi 
b/python/pyspark/streaming/dstream.pyi
deleted file mode 100644
index c9f31b37f04..00000000000
--- a/python/pyspark/streaming/dstream.pyi
+++ /dev/null
@@ -1,211 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import (
-    Callable,
-    Generic,
-    Hashable,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
-import datetime
-from pyspark.rdd import RDD
-import pyspark.serializers
-from pyspark.storagelevel import StorageLevel
-import pyspark.streaming.context
-
-from py4j.java_gateway import JavaObject
-
-S = TypeVar("S")
-T = TypeVar("T")
-T_co = TypeVar("T_co", covariant=True)
-U = TypeVar("U")
-K = TypeVar("K", bound=Hashable)
-V = TypeVar("V")
-
-class DStream(Generic[T_co]):
-    is_cached: bool
-    is_checkpointed: bool
-    def __init__(
-        self,
-        jdstream: JavaObject,
-        ssc: pyspark.streaming.context.StreamingContext,
-        jrdd_deserializer: pyspark.serializers.Serializer,
-    ) -> None: ...
-    def context(self) -> pyspark.streaming.context.StreamingContext: ...
-    def count(self) -> DStream[int]: ...
-    def filter(self, f: Callable[[T_co], bool]) -> DStream[T_co]: ...
-    def flatMap(
-        self: DStream[T_co],
-        f: Callable[[T_co], Iterable[U]],
-        preservesPartitioning: bool = ...,
-    ) -> DStream[U]: ...
-    def map(
-        self: DStream[T_co], f: Callable[[T_co], U], preservesPartitioning: 
bool = ...
-    ) -> DStream[U]: ...
-    def mapPartitions(
-        self, f: Callable[[Iterable[T_co]], Iterable[U]], 
preservesPartitioning: bool = ...
-    ) -> DStream[U]: ...
-    def mapPartitionsWithIndex(
-        self,
-        f: Callable[[int, Iterable[T_co]], Iterable[U]],
-        preservesPartitioning: bool = ...,
-    ) -> DStream[U]: ...
-    def reduce(self, func: Callable[[T_co, T_co], T_co]) -> DStream[T_co]: ...
-    def reduceByKey(
-        self: DStream[Tuple[K, V]],
-        func: Callable[[V, V], V],
-        numPartitions: Optional[int] = ...,
-    ) -> DStream[Tuple[K, V]]: ...
-    def combineByKey(
-        self: DStream[Tuple[K, V]],
-        createCombiner: Callable[[V], U],
-        mergeValue: Callable[[U, V], U],
-        mergeCombiners: Callable[[U, U], U],
-        numPartitions: Optional[int] = ...,
-    ) -> DStream[Tuple[K, U]]: ...
-    def partitionBy(
-        self: DStream[Tuple[K, V]],
-        numPartitions: int,
-        partitionFunc: Callable[[K], int] = ...,
-    ) -> DStream[Tuple[K, V]]: ...
-    @overload
-    def foreachRDD(self, func: Callable[[RDD[T_co]], None]) -> None: ...
-    @overload
-    def foreachRDD(self, func: Callable[[datetime.datetime, RDD[T_co]], None]) 
-> None: ...
-    def pprint(self, num: int = ...) -> None: ...
-    def mapValues(self: DStream[Tuple[K, V]], f: Callable[[V], U]) -> 
DStream[Tuple[K, U]]: ...
-    def flatMapValues(
-        self: DStream[Tuple[K, V]], f: Callable[[V], Iterable[U]]
-    ) -> DStream[Tuple[K, U]]: ...
-    def glom(self) -> DStream[List[T_co]]: ...
-    def cache(self) -> DStream[T_co]: ...
-    def persist(self, storageLevel: StorageLevel) -> DStream[T_co]: ...
-    def checkpoint(self, interval: int) -> DStream[T_co]: ...
-    def groupByKey(
-        self: DStream[Tuple[K, V]], numPartitions: Optional[int] = ...
-    ) -> DStream[Tuple[K, Iterable[V]]]: ...
-    def countByValue(self) -> DStream[Tuple[T_co, int]]: ...
-    def saveAsTextFiles(self, prefix: str, suffix: Optional[str] = ...) -> 
None: ...
-    @overload
-    def transform(self, func: Callable[[RDD[T_co]], RDD[U]]) -> 
TransformedDStream[U]: ...
-    @overload
-    def transform(
-        self, func: Callable[[datetime.datetime, RDD[T_co]], RDD[U]]
-    ) -> TransformedDStream[U]: ...
-    @overload
-    def transformWith(
-        self,
-        func: Callable[[RDD[T_co], RDD[U]], RDD[V]],
-        other: RDD[U],
-        keepSerializer: bool = ...,
-    ) -> DStream[V]: ...
-    @overload
-    def transformWith(
-        self,
-        func: Callable[[datetime.datetime, RDD[T_co], RDD[U]], RDD[V]],
-        other: RDD[U],
-        keepSerializer: bool = ...,
-    ) -> DStream[V]: ...
-    def repartition(self, numPartitions: int) -> DStream[T_co]: ...
-    def union(self, other: DStream[U]) -> DStream[Union[T_co, U]]: ...
-    def cogroup(
-        self: DStream[Tuple[K, V]],
-        other: DStream[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> DStream[Tuple[K, Tuple[List[V], List[U]]]]: ...
-    def join(
-        self: DStream[Tuple[K, V]],
-        other: DStream[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> DStream[Tuple[K, Tuple[V, U]]]: ...
-    def leftOuterJoin(
-        self: DStream[Tuple[K, V]],
-        other: DStream[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> DStream[Tuple[K, Tuple[V, Optional[U]]]]: ...
-    def rightOuterJoin(
-        self: DStream[Tuple[K, V]],
-        other: DStream[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> DStream[Tuple[K, Tuple[Optional[V], U]]]: ...
-    def fullOuterJoin(
-        self: DStream[Tuple[K, V]],
-        other: DStream[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> DStream[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ...
-    def slice(
-        self, begin: Union[datetime.datetime, int], end: 
Union[datetime.datetime, int]
-    ) -> List[RDD[T_co]]: ...
-    def window(self, windowDuration: int, slideDuration: Optional[int] = ...) 
-> DStream[T_co]: ...
-    def reduceByWindow(
-        self,
-        reduceFunc: Callable[[T_co, T_co], T_co],
-        invReduceFunc: Optional[Callable[[T_co, T_co], T_co]],
-        windowDuration: int,
-        slideDuration: int,
-    ) -> DStream[T_co]: ...
-    def countByWindow(
-        self, windowDuration: int, slideDuration: int
-    ) -> DStream[Tuple[T_co, int]]: ...
-    def countByValueAndWindow(
-        self,
-        windowDuration: int,
-        slideDuration: int,
-        numPartitions: Optional[int] = ...,
-    ) -> DStream[Tuple[T_co, int]]: ...
-    def groupByKeyAndWindow(
-        self: DStream[Tuple[K, V]],
-        windowDuration: int,
-        slideDuration: int,
-        numPartitions: Optional[int] = ...,
-    ) -> DStream[Tuple[K, Iterable[V]]]: ...
-    def reduceByKeyAndWindow(
-        self: DStream[Tuple[K, V]],
-        func: Callable[[V, V], V],
-        invFunc: Optional[Callable[[V, V], V]],
-        windowDuration: int,
-        slideDuration: Optional[int] = ...,
-        numPartitions: Optional[int] = ...,
-        filterFunc: Optional[Callable[[Tuple[K, V]], bool]] = ...,
-    ) -> DStream[Tuple[K, V]]: ...
-    def updateStateByKey(
-        self: DStream[Tuple[K, V]],
-        updateFunc: Callable[[Iterable[V], Optional[S]], S],
-        numPartitions: Optional[int] = ...,
-        initialRDD: Optional[RDD[Tuple[K, S]]] = ...,
-    ) -> DStream[Tuple[K, S]]: ...
-
-class TransformedDStream(DStream[U]):
-    is_cached: bool
-    is_checkpointed: bool
-    func: Callable
-    prev: DStream
-    @overload
-    def __init__(self: DStream[U], prev: DStream[T], func: Callable[[RDD[T]], 
RDD[U]]) -> None: ...
-    @overload
-    def __init__(
-        self: DStream[U],
-        prev: DStream[T],
-        func: Callable[[datetime.datetime, RDD[T]], RDD[U]],
-    ) -> None: ...


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to