This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push: new 5544cce1588 [SPARK-37015][PYTHON] Inline type hints for python/pyspark/streaming/dstream.py 5544cce1588 is described below commit 5544cce15885b1f12ae5826cd3bd2d151e1d544a Author: dch nguyen <dchvn....@gmail.com> AuthorDate: Mon Apr 18 17:38:32 2022 +0200 [SPARK-37015][PYTHON] Inline type hints for python/pyspark/streaming/dstream.py ### What changes were proposed in this pull request? Inline type hints for python/pyspark/streaming/dstream.py ### Why are the changes needed? We can take advantage of static type checking within the functions by inlining the type hints. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests Closes #34324 from dchvn/SPARK-37015. Lead-authored-by: dch nguyen <dchvn....@gmail.com> Co-authored-by: dch nguyen <dgd_contribu...@viettel.com.vn> Signed-off-by: zero323 <mszymkiew...@gmail.com> (cherry picked from commit dff52d649d1e27baf3b107f75636624e0cfe780f) Signed-off-by: zero323 <mszymkiew...@gmail.com> --- python/pyspark/streaming/context.py | 22 +-- python/pyspark/streaming/dstream.py | 369 +++++++++++++++++++++++++++-------- python/pyspark/streaming/dstream.pyi | 211 -------------------- 3 files changed, 296 insertions(+), 306 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 52e5efed063..0be0c7b034a 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -397,12 +397,12 @@ class StreamingContext: the transform function parameter will be the same as the order of corresponding DStreams in the list. """ - jdstreams = [d._jdstream for d in dstreams] # type: ignore[attr-defined] + jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction( self._sc, lambda t, *rdds: transformFunc(rdds), - *[d._jrdd_deserializer for d in dstreams], # type: ignore[attr-defined] + *[d._jrdd_deserializer for d in dstreams], ) assert self._jvm is not None @@ -419,35 +419,31 @@ class StreamingContext: raise ValueError("should have at least one DStream to union") if len(dstreams) == 1: return dstreams[0] - if len(set(s._jrdd_deserializer for s in dstreams)) > 1: # type: ignore[attr-defined] + if len(set(s._jrdd_deserializer for s in dstreams)) > 1: raise ValueError("All DStreams should have same serializer") - if len(set(s._slideDuration for s in dstreams)) > 1: # type: ignore[attr-defined] + if len(set(s._slideDuration for s in dstreams)) > 1: raise ValueError("All DStreams should have same slide duration") assert SparkContext._jvm is not None jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream gw = SparkContext._gateway - if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls): # type: ignore[attr-defined] + if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls): cls = jdstream_cls - elif is_instance_of( - gw, dstreams[0]._jdstream, jpair_dstream_cls # type: ignore[attr-defined] - ): + elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls): cls = jpair_dstream_cls else: - cls_name = ( - dstreams[0]._jdstream.getClass().getCanonicalName() # type: ignore[attr-defined] - ) + cls_name = dstreams[0]._jdstream.getClass().getCanonicalName() raise TypeError("Unsupported Java DStream class %s" % cls_name) assert gw is not None jdstreams = gw.new_array(cls, len(dstreams)) for i in range(0, 
len(dstreams)): - jdstreams[i] = dstreams[i]._jdstream # type: ignore[attr-defined] + jdstreams[i] = dstreams[i]._jdstream return DStream( self._jssc.union(jdstreams), self, - dstreams[0]._jrdd_deserializer, # type: ignore[attr-defined] + dstreams[0]._jrdd_deserializer, ) def addStreamingListener(self, streamingListener: StreamingListener) -> None: diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index f445a78bd95..934b3ae5783 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -19,19 +19,45 @@ import operator import time from itertools import chain from datetime import datetime +from typing import ( + Any, + Callable, + Generic, + Hashable, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, + TYPE_CHECKING, + cast, + overload, +) from py4j.protocol import Py4JJavaError -from pyspark import RDD from pyspark.storagelevel import StorageLevel from pyspark.streaming.util import rddToFileName, TransformFunction -from pyspark.rdd import portable_hash +from pyspark.rdd import portable_hash, RDD from pyspark.resultiterable import ResultIterable +from py4j.java_gateway import JavaObject + +if TYPE_CHECKING: + from pyspark.serializers import Serializer + from pyspark.streaming.context import StreamingContext __all__ = ["DStream"] +S = TypeVar("S") +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +U = TypeVar("U") +K = TypeVar("K", bound=Hashable) +V = TypeVar("V") + -class DStream: +class DStream(Generic[T_co]): """ A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous sequence of RDDs (of the same type) representing a @@ -51,7 +77,12 @@ class DStream: - A function that is used to generate an RDD after each time interval """ - def __init__(self, jdstream, ssc, jrdd_deserializer): + def __init__( + self, + jdstream: JavaObject, + ssc: "StreamingContext", + jrdd_deserializer: "Serializer", + ): self._jdstream = jdstream self._ssc = ssc self._sc = ssc._sc @@ -59,76 +90,94 @@ class DStream: self.is_cached = False self.is_checkpointed = False - def context(self): + def context(self) -> "StreamingContext": """ Return the StreamingContext associated with this DStream """ return self._ssc - def count(self): + def count(self) -> "DStream[int]": """ Return a new DStream in which each RDD has a single element generated by counting each RDD of this DStream. """ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add) - def filter(self, f): + def filter(self: "DStream[T]", f: Callable[[T], bool]) -> "DStream[T]": """ Return a new DStream containing only the elements that satisfy predicate. """ - def func(iterator): + def func(iterator: Iterable[T]) -> Iterable[T]: return filter(f, iterator) return self.mapPartitions(func, True) - def flatMap(self, f, preservesPartitioning=False): + def flatMap( + self: "DStream[T]", + f: Callable[[T], Iterable[U]], + preservesPartitioning: bool = False, + ) -> "DStream[U]": """ Return a new DStream by applying a function to all elements of this DStream, and then flattening the results """ - def func(s, iterator): + def func(s: int, iterator: Iterable[T]) -> Iterable[U]: return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) - def map(self, f, preservesPartitioning=False): + def map( + self: "DStream[T]", f: Callable[[T], U], preservesPartitioning: bool = False + ) -> "DStream[U]": """ Return a new DStream by applying a function to each element of DStream. 
""" - def func(iterator): + def func(iterator: Iterable[T]) -> Iterable[U]: return map(f, iterator) return self.mapPartitions(func, preservesPartitioning) - def mapPartitions(self, f, preservesPartitioning=False): + def mapPartitions( + self: "DStream[T]", + f: Callable[[Iterable[T]], Iterable[U]], + preservesPartitioning: bool = False, + ) -> "DStream[U]": """ Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs of this DStream. """ - def func(s, iterator): + def func(s: int, iterator: Iterable[T]) -> Iterable[U]: return f(iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + def mapPartitionsWithIndex( + self: "DStream[T]", + f: Callable[[int, Iterable[T]], Iterable[U]], + preservesPartitioning: bool = False, + ) -> "DStream[U]": """ Return a new DStream in which each RDD is generated by applying mapPartitionsWithIndex() to each RDDs of this DStream. """ return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning)) - def reduce(self, func): + def reduce(self: "DStream[T]", func: Callable[[T, T], T]) -> "DStream[T]": """ Return a new DStream in which each RDD has a single element generated by reducing each RDD of this DStream. """ return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1]) - def reduceByKey(self, func, numPartitions=None): + def reduceByKey( + self: "DStream[Tuple[K, V]]", + func: Callable[[V, V], V], + numPartitions: Optional[int] = None, + ) -> "DStream[Tuple[K, V]]": """ Return a new DStream by applying reduceByKey to each RDD. """ @@ -136,40 +185,62 @@ class DStream: numPartitions = self._sc.defaultParallelism return self.combineByKey(lambda x: x, func, func, numPartitions) - def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions=None): + def combineByKey( + self: "DStream[Tuple[K, V]]", + createCombiner: Callable[[V], U], + mergeValue: Callable[[U, V], U], + mergeCombiners: Callable[[U, U], U], + numPartitions: Optional[int] = None, + ) -> "DStream[Tuple[K, U]]": """ Return a new DStream by applying combineByKey to each RDD. """ if numPartitions is None: numPartitions = self._sc.defaultParallelism - def func(rdd): + def func(rdd: RDD[Tuple[K, V]]) -> RDD[Tuple[K, U]]: return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) return self.transform(func) - def partitionBy(self, numPartitions, partitionFunc=portable_hash): + def partitionBy( + self: "DStream[Tuple[K, V]]", + numPartitions: int, + partitionFunc: Callable[[K], int] = portable_hash, + ) -> "DStream[Tuple[K, V]]": """ Return a copy of the DStream in which each RDD are partitioned using the specified partitioner. """ return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc)) - def foreachRDD(self, func): + @overload + def foreachRDD(self: "DStream[T]", func: Callable[[RDD[T]], None]) -> None: + ... + + @overload + def foreachRDD(self: "DStream[T]", func: Callable[[datetime, RDD[T]], None]) -> None: + ... + + def foreachRDD( + self: "DStream[T]", + func: Union[Callable[[RDD[T]], None], Callable[[datetime, RDD[T]], None]], + ) -> None: """ Apply a function to each RDD in this DStream. 
""" if func.__code__.co_argcount == 1: old_func = func - def func(_, rdd): - return old_func(rdd) + def func(_: datetime, rdd: "RDD[T]") -> None: + return old_func(rdd) # type: ignore[call-arg, arg-type] jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) + assert self._ssc._jvm is not None api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc) - def pprint(self, num=10): + def pprint(self, num: int = 10) -> None: """ Print the first num elements of each RDD generated in this DStream. @@ -179,7 +250,7 @@ class DStream: the number of elements from the first will be printed. """ - def takeAndPrint(time, rdd): + def takeAndPrint(time: datetime, rdd: RDD[T]) -> None: taken = rdd.take(num + 1) print("-------------------------------------------") print("Time: %s" % time) @@ -192,40 +263,42 @@ class DStream: self.foreachRDD(takeAndPrint) - def mapValues(self, f): + def mapValues(self: "DStream[Tuple[K, V]]", f: Callable[[V], U]) -> "DStream[Tuple[K, U]]": """ Return a new DStream by applying a map function to the value of each key-value pairs in this DStream without changing the key. """ - def map_values_fn(kv): + def map_values_fn(kv: Tuple[K, V]) -> Tuple[K, U]: return kv[0], f(kv[1]) return self.map(map_values_fn, preservesPartitioning=True) - def flatMapValues(self, f): + def flatMapValues( + self: "DStream[Tuple[K, V]]", f: Callable[[V], Iterable[U]] + ) -> "DStream[Tuple[K, U]]": """ Return a new DStream by applying a flatmap function to the value of each key-value pairs in this DStream without changing the key. """ - def flat_map_fn(kv): + def flat_map_fn(kv: Tuple[K, V]) -> Iterable[Tuple[K, U]]: return ((kv[0], x) for x in f(kv[1])) return self.flatMap(flat_map_fn, preservesPartitioning=True) - def glom(self): + def glom(self: "DStream[T]") -> "DStream[List[T]]": """ Return a new DStream in which RDD is generated by applying glom() to RDD of this DStream. """ - def func(iterator): + def func(iterator: Iterable[T]) -> Iterable[List[T]]: yield list(iterator) return self.mapPartitions(func) - def cache(self): + def cache(self: "DStream[T]") -> "DStream[T]": """ Persist the RDDs of this DStream with the default storage level (`MEMORY_ONLY`). @@ -234,7 +307,7 @@ class DStream: self.persist(StorageLevel.MEMORY_ONLY) return self - def persist(self, storageLevel): + def persist(self: "DStream[T]", storageLevel: StorageLevel) -> "DStream[T]": """ Persist the RDDs of this DStream with the given storage level """ @@ -243,7 +316,7 @@ class DStream: self._jdstream.persist(javaStorageLevel) return self - def checkpoint(self, interval): + def checkpoint(self: "DStream[T]", interval: int) -> "DStream[T]": """ Enable periodic checkpointing of RDDs of this DStream @@ -257,7 +330,9 @@ class DStream: self._jdstream.checkpoint(self._ssc._jduration(interval)) return self - def groupByKey(self, numPartitions=None): + def groupByKey( + self: "DStream[Tuple[K, V]]", numPartitions: Optional[int] = None + ) -> "DStream[Tuple[K, Iterable[V]]]": """ Return a new DStream by applying groupByKey on each RDD. """ @@ -265,20 +340,20 @@ class DStream: numPartitions = self._sc.defaultParallelism return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) - def countByValue(self): + def countByValue(self: "DStream[K]") -> "DStream[Tuple[K, int]]": """ Return a new DStream in which each RDD contains the counts of each distinct value in each RDD of this DStream. 
""" return self.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) - def saveAsTextFiles(self, prefix, suffix=None): + def saveAsTextFiles(self, prefix: str, suffix: Optional[str] = None) -> None: """ Save each RDD in this DStream as at text file, using string representation of elements. """ - def saveAsTextFile(t, rdd): + def saveAsTextFile(t: Optional[datetime], rdd: RDD[T]) -> None: path = rddToFileName(prefix, suffix, t) try: rdd.saveAsTextFile(path) @@ -307,7 +382,20 @@ class DStream: # raise # return self.foreachRDD(saveAsPickleFile) - def transform(self, func): + @overload + def transform(self: "DStream[T]", func: Callable[[RDD[T]], RDD[U]]) -> "TransformedDStream[U]": + ... + + @overload + def transform( + self: "DStream[T]", func: Callable[[datetime, RDD[T]], RDD[U]] + ) -> "TransformedDStream[U]": + ... + + def transform( + self: "DStream[T]", + func: Union[Callable[[RDD[T]], RDD[U]], Callable[[datetime, RDD[T]], RDD[U]]], + ) -> "TransformedDStream[U]": """ Return a new DStream in which each RDD is generated by applying a function on each RDD of this DStream. @@ -318,13 +406,39 @@ class DStream: if func.__code__.co_argcount == 1: oldfunc = func - def func(_, rdd): - return oldfunc(rdd) + def func(_: datetime, rdd: RDD[T]) -> RDD[U]: + return oldfunc(rdd) # type: ignore[arg-type, call-arg] assert func.__code__.co_argcount == 2, "func should take one or two arguments" return TransformedDStream(self, func) - def transformWith(self, func, other, keepSerializer=False): + @overload + def transformWith( + self: "DStream[T]", + func: Callable[[RDD[T], RDD[U]], RDD[V]], + other: "DStream[U]", + keepSerializer: bool = ..., + ) -> "DStream[V]": + ... + + @overload + def transformWith( + self: "DStream[T]", + func: Callable[[datetime, RDD[T], RDD[U]], RDD[V]], + other: "DStream[U]", + keepSerializer: bool = ..., + ) -> "DStream[V]": + ... + + def transformWith( + self: "DStream[T]", + func: Union[ + Callable[[RDD[T], RDD[U]], RDD[V]], + Callable[[datetime, RDD[T], RDD[U]], RDD[V]], + ], + other: "DStream[U]", + keepSerializer: bool = False, + ) -> "DStream[V]": """ Return a new DStream in which each RDD is generated by applying a function on each RDD of this DStream and 'other' DStream. @@ -335,31 +449,37 @@ class DStream: if func.__code__.co_argcount == 2: oldfunc = func - def func(_, a, b): - return oldfunc(a, b) + def func(_: datetime, a: RDD[T], b: RDD[U]) -> RDD[V]: + return oldfunc(a, b) # type: ignore[call-arg, arg-type] assert func.__code__.co_argcount == 3, "func should take two or three arguments" - jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) + jfunc = TransformFunction( + self._sc, + func, + self._jrdd_deserializer, + other._jrdd_deserializer, + ) + assert self._sc._jvm is not None dstream = self._sc._jvm.PythonTransformed2DStream( self._jdstream.dstream(), other._jdstream.dstream(), jfunc ) jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) - def repartition(self, numPartitions): + def repartition(self: "DStream[T]", numPartitions: int) -> "DStream[T]": """ Return a new DStream with an increased or decreased level of parallelism. 
""" return self.transform(lambda rdd: rdd.repartition(numPartitions)) @property - def _slideDuration(self): + def _slideDuration(self) -> None: """ Return the slideDuration in seconds of this DStream """ return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0 - def union(self, other): + def union(self: "DStream[T]", other: "DStream[U]") -> "DStream[Union[T, U]]": """ Return a new DStream by unifying data of another DStream with this DStream. @@ -373,7 +493,11 @@ class DStream: raise ValueError("the two DStream should have same slide duration") return self.transformWith(lambda a, b: a.union(b), other, True) - def cogroup(self, other, numPartitions=None): + def cogroup( + self: "DStream[Tuple[K, V]]", + other: "DStream[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "DStream[Tuple[K, Tuple[ResultIterable[V], ResultIterable[U]]]]": """ Return a new DStream by applying 'cogroup' between RDDs of this DStream and `other` DStream. @@ -382,9 +506,16 @@ class DStream: """ if numPartitions is None: numPartitions = self._sc.defaultParallelism - return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) + return self.transformWith( + lambda a, b: a.cogroup(b, numPartitions), + other, + ) - def join(self, other, numPartitions=None): + def join( + self: "DStream[Tuple[K, V]]", + other: "DStream[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "DStream[Tuple[K, Tuple[V, U]]]": """ Return a new DStream by applying 'join' between RDDs of this DStream and `other` DStream. @@ -396,7 +527,11 @@ class DStream: numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.join(b, numPartitions), other) - def leftOuterJoin(self, other, numPartitions=None): + def leftOuterJoin( + self: "DStream[Tuple[K, V]]", + other: "DStream[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "DStream[Tuple[K, Tuple[V, Optional[U]]]]": """ Return a new DStream by applying 'left outer join' between RDDs of this DStream and `other` DStream. @@ -408,7 +543,11 @@ class DStream: numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) - def rightOuterJoin(self, other, numPartitions=None): + def rightOuterJoin( + self: "DStream[Tuple[K, V]]", + other: "DStream[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "DStream[Tuple[K, Tuple[Optional[V], U]]]": """ Return a new DStream by applying 'right outer join' between RDDs of this DStream and `other` DStream. @@ -420,7 +559,11 @@ class DStream: numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) - def fullOuterJoin(self, other, numPartitions=None): + def fullOuterJoin( + self: "DStream[Tuple[K, V]]", + other: "DStream[Tuple[K, U]]", + numPartitions: Optional[int] = None, + ) -> "DStream[Tuple[K, Tuple[Optional[V], Optional[U]]]]": """ Return a new DStream by applying 'full outer join' between RDDs of this DStream and `other` DStream. 
@@ -432,13 +575,14 @@ class DStream: numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) - def _jtime(self, timestamp): + def _jtime(self, timestamp: Union[datetime, int, float]) -> JavaObject: """Convert datetime or unix_timestamp into Time""" if isinstance(timestamp, datetime): timestamp = time.mktime(timestamp.timetuple()) + assert self._sc._jvm is not None return self._sc._jvm.Time(int(timestamp * 1000)) - def slice(self, begin, end): + def slice(self, begin: Union[datetime, int], end: Union[datetime, int]) -> List[RDD[T]]: """ Return all the RDDs between 'begin' to 'end' (both included) @@ -447,7 +591,7 @@ class DStream: jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds] - def _validate_window_param(self, window, slide): + def _validate_window_param(self, window: int, slide: Optional[int]) -> None: duration = self._jdstream.dstream().slideDuration().milliseconds() if int(window * 1000) % duration != 0: raise ValueError( @@ -460,7 +604,7 @@ class DStream: "dstream's slide (batch) duration (%d ms)" % duration ) - def window(self, windowDuration, slideDuration=None): + def window(self, windowDuration: int, slideDuration: Optional[int] = None) -> "DStream[T]": """ Return a new DStream in which each RDD contains all the elements in seen in a sliding window of time over this DStream. @@ -482,7 +626,13 @@ class DStream: s = self._ssc._jduration(slideDuration) return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer) - def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration): + def reduceByWindow( + self: "DStream[T]", + reduceFunc: Callable[[T, T], T], + invReduceFunc: Optional[Callable[[T, T], T]], + windowDuration: int, + slideDuration: int, + ) -> "DStream[T]": """ Return a new DStream in which each RDD has a single element generated by reducing all elements in a sliding window over this DStream. @@ -517,7 +667,9 @@ class DStream: ) return reduced.map(lambda kv: kv[1]) - def countByWindow(self, windowDuration, slideDuration): + def countByWindow( + self: "DStream[T]", windowDuration: int, slideDuration: int + ) -> "DStream[int]": """ Return a new DStream in which each RDD has a single element generated by counting the number of elements in a window over this DStream. @@ -530,7 +682,12 @@ class DStream: operator.add, operator.sub, windowDuration, slideDuration ) - def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None): + def countByValueAndWindow( + self: "DStream[T]", + windowDuration: int, + slideDuration: int, + numPartitions: Optional[int] = None, + ) -> "DStream[Tuple[T, int]]": """ Return a new DStream in which each RDD contains the count of distinct elements in RDDs in a sliding window over this DStream. @@ -553,7 +710,12 @@ class DStream: ) return counted.filter(lambda kv: kv[1] > 0) - def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): + def groupByKeyAndWindow( + self: "DStream[Tuple[K, V]]", + windowDuration: int, + slideDuration: int, + numPartitions: Optional[int] = None, + ) -> "DStream[Tuple[K, Iterable[V]]]": """ Return a new DStream by applying `groupByKey` over a sliding window. Similar to `DStream.groupByKey()`, but applies it over a sliding window. 
@@ -572,7 +734,7 @@ class DStream: """ ls = self.mapValues(lambda x: [x]) grouped = ls.reduceByKeyAndWindow( - lambda a, b: a.extend(b) or a, + lambda a, b: a.extend(b) or a, # type: ignore[func-returns-value] lambda a, b: a[len(b) :], windowDuration, slideDuration, @@ -581,8 +743,14 @@ class DStream: return grouped.mapValues(ResultIterable) def reduceByKeyAndWindow( - self, func, invFunc, windowDuration, slideDuration=None, numPartitions=None, filterFunc=None - ): + self: "DStream[Tuple[K, V]]", + func: Callable[[V, V], V], + invFunc: Optional[Callable[[V, V], V]], + windowDuration: int, + slideDuration: Optional[int] = None, + numPartitions: Optional[int] = None, + filterFunc: Optional[Callable[[Tuple[K, V]], bool]] = None, + ) -> "DStream[Tuple[K, V]]": """ Return a new DStream by applying incremental `reduceByKey` over a sliding window. @@ -621,36 +789,46 @@ class DStream: if invFunc: - def reduceFunc(t, a, b): + def reduceFunc(t: datetime, a: Any, b: Any) -> Any: b = b.reduceByKey(func, numPartitions) r = a.union(b).reduceByKey(func, numPartitions) if a else b if filterFunc: r = r.filter(filterFunc) return r - def invReduceFunc(t, a, b): + def invReduceFunc(t: datetime, a: Any, b: Any) -> Any: b = b.reduceByKey(func, numPartitions) joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues( - lambda kv: invFunc(kv[0], kv[1]) if kv[1] is not None else kv[0] + lambda kv: invFunc(kv[0], kv[1]) # type: ignore[misc] + if kv[1] is not None + else kv[0] ) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) if slideDuration is None: slideDuration = self._slideDuration + assert self._sc._jvm is not None dstream = self._sc._jvm.PythonReducedWindowedDStream( reduced._jdstream.dstream(), jreduceFunc, jinvReduceFunc, self._ssc._jduration(windowDuration), - self._ssc._jduration(slideDuration), + self._ssc._jduration(slideDuration), # type: ignore[arg-type] ) return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) else: - return reduced.window(windowDuration, slideDuration).reduceByKey(func, numPartitions) + return reduced.window(windowDuration, slideDuration).reduceByKey( + func, numPartitions # type: ignore[arg-type] + ) - def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None): + def updateStateByKey( + self: "DStream[Tuple[K, V]]", + updateFunc: Callable[[Iterable[V], Optional[S]], S], + numPartitions: Optional[int] = None, + initialRDD: Optional[Union[RDD[Tuple[K, S]], Iterable[Tuple[K, S]]]] = None, + ) -> "DStream[Tuple[K, S]]": """ Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of the key. 
@@ -667,30 +845,37 @@ class DStream: if initialRDD and not isinstance(initialRDD, RDD): initialRDD = self._sc.parallelize(initialRDD) - def reduceFunc(t, a, b): + def reduceFunc(t: datetime, a: Any, b: Any) -> Any: if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) else: - g = a.cogroup(b.partitionBy(numPartitions), numPartitions) + g = a.cogroup(b.partitionBy(cast(int, numPartitions)), numPartitions) g = g.mapValues(lambda ab: (list(ab[1]), list(ab[0])[0] if len(ab[0]) else None)) state = g.mapValues(lambda vs_s: updateFunc(vs_s[0], vs_s[1])) return state.filter(lambda k_v: k_v[1] is not None) jreduceFunc = TransformFunction( - self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer + self._sc, + reduceFunc, + self._sc.serializer, + self._jrdd_deserializer, ) if initialRDD: - initialRDD = initialRDD._reserialize(self._jrdd_deserializer) + initialRDD = cast(RDD[Tuple[K, S]], initialRDD)._reserialize(self._jrdd_deserializer) + assert self._sc._jvm is not None dstream = self._sc._jvm.PythonStateDStream( - self._jdstream.dstream(), jreduceFunc, initialRDD._jrdd + self._jdstream.dstream(), + jreduceFunc, + initialRDD._jrdd, ) else: + assert self._sc._jvm is not None dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) -class TransformedDStream(DStream): +class TransformedDStream(DStream[U]): """ TransformedDStream is a DStream generated by an Python function transforming each RDD of a DStream to another RDDs. @@ -699,7 +884,23 @@ class TransformedDStream(DStream): one transformation. """ - def __init__(self, prev, func): + @overload + def __init__(self: DStream[U], prev: DStream[T], func: Callable[[RDD[T]], RDD[U]]): + ... + + @overload + def __init__( + self: DStream[U], + prev: DStream[T], + func: Callable[[datetime, RDD[T]], RDD[U]], + ): + ... + + def __init__( + self, + prev: DStream[T], + func: Union[Callable[[RDD[T]], RDD[U]], Callable[[datetime, RDD[T]], RDD[U]]], + ): self._ssc = prev._ssc self._sc = self._ssc._sc self._jrdd_deserializer = self._sc.serializer @@ -710,19 +911,23 @@ class TransformedDStream(DStream): # Using type() to avoid folding the functions and compacting the DStreams which is not # not strictly an object of TransformedDStream. 
if type(prev) is TransformedDStream and not prev.is_cached and not prev.is_checkpointed: - prev_func = prev.func - self.func = lambda t, rdd: func(t, prev_func(t, rdd)) - self.prev = prev.prev + prev_func: Callable = prev.func + func = cast(Callable[[datetime, RDD[T]], RDD[U]], func) + self.func: Union[ + Callable[[RDD[T]], RDD[U]], Callable[[datetime, RDD[T]], RDD[U]] + ] = lambda t, rdd: func(t, prev_func(t, rdd)) + self.prev: DStream[T] = prev.prev else: self.prev = prev self.func = func @property - def _jdstream(self): + def _jdstream(self) -> JavaObject: if self._jdstream_val is not None: return self._jdstream_val jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) + assert self._sc._jvm is not None dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) self._jdstream_val = dstream.asJavaDStream() return self._jdstream_val diff --git a/python/pyspark/streaming/dstream.pyi b/python/pyspark/streaming/dstream.pyi deleted file mode 100644 index c9f31b37f04..00000000000 --- a/python/pyspark/streaming/dstream.pyi +++ /dev/null @@ -1,211 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import overload -from typing import ( - Callable, - Generic, - Hashable, - Iterable, - List, - Optional, - Tuple, - TypeVar, - Union, -) -import datetime -from pyspark.rdd import RDD -import pyspark.serializers -from pyspark.storagelevel import StorageLevel -import pyspark.streaming.context - -from py4j.java_gateway import JavaObject - -S = TypeVar("S") -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) -U = TypeVar("U") -K = TypeVar("K", bound=Hashable) -V = TypeVar("V") - -class DStream(Generic[T_co]): - is_cached: bool - is_checkpointed: bool - def __init__( - self, - jdstream: JavaObject, - ssc: pyspark.streaming.context.StreamingContext, - jrdd_deserializer: pyspark.serializers.Serializer, - ) -> None: ... - def context(self) -> pyspark.streaming.context.StreamingContext: ... - def count(self) -> DStream[int]: ... - def filter(self, f: Callable[[T_co], bool]) -> DStream[T_co]: ... - def flatMap( - self: DStream[T_co], - f: Callable[[T_co], Iterable[U]], - preservesPartitioning: bool = ..., - ) -> DStream[U]: ... - def map( - self: DStream[T_co], f: Callable[[T_co], U], preservesPartitioning: bool = ... - ) -> DStream[U]: ... - def mapPartitions( - self, f: Callable[[Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ... - ) -> DStream[U]: ... - def mapPartitionsWithIndex( - self, - f: Callable[[int, Iterable[T_co]], Iterable[U]], - preservesPartitioning: bool = ..., - ) -> DStream[U]: ... - def reduce(self, func: Callable[[T_co, T_co], T_co]) -> DStream[T_co]: ... 
- def reduceByKey( - self: DStream[Tuple[K, V]], - func: Callable[[V, V], V], - numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[K, V]]: ... - def combineByKey( - self: DStream[Tuple[K, V]], - createCombiner: Callable[[V], U], - mergeValue: Callable[[U, V], U], - mergeCombiners: Callable[[U, U], U], - numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[K, U]]: ... - def partitionBy( - self: DStream[Tuple[K, V]], - numPartitions: int, - partitionFunc: Callable[[K], int] = ..., - ) -> DStream[Tuple[K, V]]: ... - @overload - def foreachRDD(self, func: Callable[[RDD[T_co]], None]) -> None: ... - @overload - def foreachRDD(self, func: Callable[[datetime.datetime, RDD[T_co]], None]) -> None: ... - def pprint(self, num: int = ...) -> None: ... - def mapValues(self: DStream[Tuple[K, V]], f: Callable[[V], U]) -> DStream[Tuple[K, U]]: ... - def flatMapValues( - self: DStream[Tuple[K, V]], f: Callable[[V], Iterable[U]] - ) -> DStream[Tuple[K, U]]: ... - def glom(self) -> DStream[List[T_co]]: ... - def cache(self) -> DStream[T_co]: ... - def persist(self, storageLevel: StorageLevel) -> DStream[T_co]: ... - def checkpoint(self, interval: int) -> DStream[T_co]: ... - def groupByKey( - self: DStream[Tuple[K, V]], numPartitions: Optional[int] = ... - ) -> DStream[Tuple[K, Iterable[V]]]: ... - def countByValue(self) -> DStream[Tuple[T_co, int]]: ... - def saveAsTextFiles(self, prefix: str, suffix: Optional[str] = ...) -> None: ... - @overload - def transform(self, func: Callable[[RDD[T_co]], RDD[U]]) -> TransformedDStream[U]: ... - @overload - def transform( - self, func: Callable[[datetime.datetime, RDD[T_co]], RDD[U]] - ) -> TransformedDStream[U]: ... - @overload - def transformWith( - self, - func: Callable[[RDD[T_co], RDD[U]], RDD[V]], - other: RDD[U], - keepSerializer: bool = ..., - ) -> DStream[V]: ... - @overload - def transformWith( - self, - func: Callable[[datetime.datetime, RDD[T_co], RDD[U]], RDD[V]], - other: RDD[U], - keepSerializer: bool = ..., - ) -> DStream[V]: ... - def repartition(self, numPartitions: int) -> DStream[T_co]: ... - def union(self, other: DStream[U]) -> DStream[Union[T_co, U]]: ... - def cogroup( - self: DStream[Tuple[K, V]], - other: DStream[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[K, Tuple[List[V], List[U]]]]: ... - def join( - self: DStream[Tuple[K, V]], - other: DStream[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[K, Tuple[V, U]]]: ... - def leftOuterJoin( - self: DStream[Tuple[K, V]], - other: DStream[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[K, Tuple[V, Optional[U]]]]: ... - def rightOuterJoin( - self: DStream[Tuple[K, V]], - other: DStream[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[K, Tuple[Optional[V], U]]]: ... - def fullOuterJoin( - self: DStream[Tuple[K, V]], - other: DStream[Tuple[K, U]], - numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ... - def slice( - self, begin: Union[datetime.datetime, int], end: Union[datetime.datetime, int] - ) -> List[RDD[T_co]]: ... - def window(self, windowDuration: int, slideDuration: Optional[int] = ...) -> DStream[T_co]: ... - def reduceByWindow( - self, - reduceFunc: Callable[[T_co, T_co], T_co], - invReduceFunc: Optional[Callable[[T_co, T_co], T_co]], - windowDuration: int, - slideDuration: int, - ) -> DStream[T_co]: ... - def countByWindow( - self, windowDuration: int, slideDuration: int - ) -> DStream[Tuple[T_co, int]]: ... 
- def countByValueAndWindow( - self, - windowDuration: int, - slideDuration: int, - numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[T_co, int]]: ... - def groupByKeyAndWindow( - self: DStream[Tuple[K, V]], - windowDuration: int, - slideDuration: int, - numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[K, Iterable[V]]]: ... - def reduceByKeyAndWindow( - self: DStream[Tuple[K, V]], - func: Callable[[V, V], V], - invFunc: Optional[Callable[[V, V], V]], - windowDuration: int, - slideDuration: Optional[int] = ..., - numPartitions: Optional[int] = ..., - filterFunc: Optional[Callable[[Tuple[K, V]], bool]] = ..., - ) -> DStream[Tuple[K, V]]: ... - def updateStateByKey( - self: DStream[Tuple[K, V]], - updateFunc: Callable[[Iterable[V], Optional[S]], S], - numPartitions: Optional[int] = ..., - initialRDD: Optional[RDD[Tuple[K, S]]] = ..., - ) -> DStream[Tuple[K, S]]: ... - -class TransformedDStream(DStream[U]): - is_cached: bool - is_checkpointed: bool - func: Callable - prev: DStream - @overload - def __init__(self: DStream[U], prev: DStream[T], func: Callable[[RDD[T]], RDD[U]]) -> None: ... - @overload - def __init__( - self: DStream[U], - prev: DStream[T], - func: Callable[[datetime.datetime, RDD[T]], RDD[U]], - ) -> None: ...
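As the PR description notes, inlining the hints lets static type checking cover the function bodies in dstream.py itself, not only callers of the removed dstream.pyi stub. A minimal sketch of how the inlined DStream annotations surface for users (hypothetical example, not part of this patch; the socket source and word-count pipeline are illustrative only):

```python
# Hypothetical usage, not part of this commit: with the hints inlined in
# dstream.py, a type checker can follow element types through the pipeline.
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext(appName="typed-dstream-sketch")
ssc = StreamingContext(sc, batchDuration=1)

lines = ssc.socketTextStream("localhost", 9999)   # DStream[str]
words = lines.flatMap(lambda line: line.split())  # DStream[str]
pairs = words.map(lambda word: (word, 1))         # DStream[Tuple[str, int]]
counts = pairs.reduceByKey(lambda a, b: a + b)    # DStream[Tuple[str, int]]
counts.pprint()

ssc.start()
ssc.awaitTermination()
```

Running mypy over a script like this should resolve `counts` to `DStream[Tuple[str, int]]` from the inlined signatures alone, which is what the new generic `DStream[T_co]` class and the `self: "DStream[Tuple[K, V]]"` annotations in the diff above provide.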