This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new c0c1f35cd92  [SPARK-37014][PYTHON] Inline type hints for python/pyspark/streaming/context.py
c0c1f35cd92 is described below

commit c0c1f35cd9279bc1a7a50119be72a297162a9b55
Author: dch nguyen <dchvn....@gmail.com>
AuthorDate: Thu Apr 14 02:03:24 2022 +0200

    [SPARK-37014][PYTHON] Inline type hints for python/pyspark/streaming/context.py

    ### What changes were proposed in this pull request?
    Inline the type hints for python/pyspark/streaming/context.py from the stub file python/pyspark/streaming/context.pyi.

    ### Why are the changes needed?
    Currently, the type hint stub file python/pyspark/streaming/context.pyi shows the expected types for functions, but we can also take advantage of static type checking within the function bodies by inlining the type hints.

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Existing tests.

    Closes #34293 from dchvn/SPARK-37014.

    Authored-by: dch nguyen <dchvn....@gmail.com>
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
---
 python/pyspark/streaming/context.py  | 123 ++++++++++++++++++++++++-----------
 python/pyspark/streaming/context.pyi |  71 --------------------
 python/pyspark/streaming/kinesis.py  |   9 +--
 3 files changed, 91 insertions(+), 112 deletions(-)

diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index cc9875d6575..52e5efed063 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -14,18 +14,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import Any, Callable, List, Optional, TypeVar
 
-from py4j.java_gateway import java_import, is_instance_of
+from py4j.java_gateway import java_import, is_instance_of, JavaObject
 
 from pyspark import RDD, SparkConf
 from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer
 from pyspark.context import SparkContext
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.dstream import DStream
+from pyspark.streaming.listener import StreamingListener
 from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer
 
 __all__ = ["StreamingContext"]
 
+T = TypeVar("T")
+
 
 class StreamingContext:
     """
@@ -51,27 +55,35 @@ class StreamingContext:
     # Reference to a currently active StreamingContext
     _activeContext = None
 
-    def __init__(self, sparkContext, batchDuration=None, jssc=None):
-
+    def __init__(
+        self,
+        sparkContext: SparkContext,
+        batchDuration: Optional[int] = None,
+        jssc: Optional[JavaObject] = None,
+    ):
         self._sc = sparkContext
         self._jvm = self._sc._jvm
         self._jssc = jssc or self._initialize_context(self._sc, batchDuration)
 
-    def _initialize_context(self, sc, duration):
+    def _initialize_context(self, sc: SparkContext, duration: Optional[int]) -> JavaObject:
         self._ensure_initialized()
+        assert self._jvm is not None and duration is not None
         return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
 
-    def _jduration(self, seconds):
+    def _jduration(self, seconds: int) -> JavaObject:
         """
         Create Duration object given number of seconds
         """
+        assert self._jvm is not None
         return self._jvm.Duration(int(seconds * 1000))
 
     @classmethod
-    def _ensure_initialized(cls):
+    def _ensure_initialized(cls) -> None:
         SparkContext._ensure_initialized()
         gw = SparkContext._gateway
+        assert gw is not None
+
         java_import(gw.jvm, "org.apache.spark.streaming.*")
         java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
         java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
@@ -83,11 +95,15 @@ class StreamingContext:
 
         # register serializer for TransformFunction
         # it happens before creating SparkContext when loading from checkpointing
         cls._transformerSerializer = TransformFunctionSerializer(
-            SparkContext._active_spark_context, CloudPickleSerializer(), gw
+            SparkContext._active_spark_context,
+            CloudPickleSerializer(),
+            gw,
         )
 
     @classmethod
-    def getOrCreate(cls, checkpointPath, setupFunc):
+    def getOrCreate(
+        cls, checkpointPath: str, setupFunc: Callable[[], "StreamingContext"]
+    ) -> "StreamingContext":
         """
         Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
         If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
@@ -104,6 +120,8 @@ class StreamingContext:
 
         cls._ensure_initialized()
         gw = SparkContext._gateway
+        assert gw is not None
+
         # Check whether valid checkpoint information exists in the given path
         ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath)
         if ssc_option.isEmpty():
@@ -121,12 +139,15 @@ class StreamingContext:
 
         sc = SparkContext._active_spark_context
 
+        assert sc is not None
+
         # update ctx in serializer
+        assert cls._transformerSerializer is not None
         cls._transformerSerializer.ctx = sc
         return StreamingContext(sc, None, jssc)
 
     @classmethod
-    def getActive(cls):
+    def getActive(cls) -> Optional["StreamingContext"]:
         """
         Return either the currently active StreamingContext (i.e., if there is a context
         started but not stopped) or None.
@@ -149,7 +170,9 @@ class StreamingContext:
         return cls._activeContext
 
     @classmethod
-    def getActiveOrCreate(cls, checkpointPath, setupFunc):
+    def getActiveOrCreate(
+        cls, checkpointPath: str, setupFunc: Callable[[], "StreamingContext"]
+    ) -> "StreamingContext":
         """
         Either return the active StreamingContext (i.e. currently started but not stopped),
         or recreate a StreamingContext from checkpoint data or create a new StreamingContext
@@ -178,20 +201,20 @@ class StreamingContext:
         return setupFunc()
 
     @property
-    def sparkContext(self):
+    def sparkContext(self) -> SparkContext:
         """
         Return SparkContext which is associated with this StreamingContext.
         """
         return self._sc
 
-    def start(self):
+    def start(self) -> None:
         """
         Start the execution of the streams.
         """
         self._jssc.start()
         StreamingContext._activeContext = self
 
-    def awaitTermination(self, timeout=None):
+    def awaitTermination(self, timeout: Optional[int] = None) -> None:
         """
         Wait for the execution to stop.
@@ -205,7 +228,7 @@ class StreamingContext:
         else:
             self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))
 
-    def awaitTerminationOrTimeout(self, timeout):
+    def awaitTerminationOrTimeout(self, timeout: int) -> None:
         """
         Wait for the execution to stop. Return `true` if it's stopped; or
         throw the reported error during the execution; or `false` if the
@@ -218,7 +241,7 @@ class StreamingContext:
         """
         return self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))
 
-    def stop(self, stopSparkContext=True, stopGraceFully=False):
+    def stop(self, stopSparkContext: bool = True, stopGraceFully: bool = False) -> None:
         """
         Stop the execution of the streams, with option of ensuring all
         received data has been processed.
@@ -236,7 +259,7 @@ class StreamingContext:
         if stopSparkContext:
             self._sc.stop()
 
-    def remember(self, duration):
+    def remember(self, duration: int) -> None:
         """
         Set each DStreams in this context to remember RDDs it generated
         in the last given duration. DStreams remember RDDs only for a
@@ -252,7 +275,7 @@ class StreamingContext:
         """
         self._jssc.remember(self._jduration(duration))
 
-    def checkpoint(self, directory):
+    def checkpoint(self, directory: str) -> None:
         """
         Sets the context to periodically checkpoint the DStream operations for master
         fault-tolerance. The graph will be checkpointed every batch interval.
@@ -264,7 +287,9 @@ class StreamingContext:
         """
         self._jssc.checkpoint(directory)
 
-    def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_2):
+    def socketTextStream(
+        self, hostname: str, port: int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_2
+    ) -> "DStream[str]":
         """
         Create an input from TCP source hostname:port. Data is received using
         a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited
@@ -284,7 +309,7 @@ class StreamingContext:
             self._jssc.socketTextStream(hostname, port, jlevel), self, UTF8Deserializer()
         )
 
-    def textFileStream(self, directory):
+    def textFileStream(self, directory: str) -> "DStream[str]":
         """
         Create an input stream that monitors a Hadoop-compatible file system
         for new files and reads them as text files. Files must be written to the
@@ -294,7 +319,7 @@ class StreamingContext:
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
 
-    def binaryRecordsStream(self, directory, recordLength):
+    def binaryRecordsStream(self, directory: str, recordLength: int) -> "DStream[bytes]":
         """
         Create an input stream that monitors a Hadoop-compatible file system
         for new files and reads them as flat binary files with records of
@@ -313,14 +338,19 @@ class StreamingContext:
             self._jssc.binaryRecordsStream(directory, recordLength), self, NoOpSerializer()
         )
 
-    def _check_serializers(self, rdds):
+    def _check_serializers(self, rdds: List[RDD[T]]) -> None:
         # make sure they have same serializer
         if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
             for i in range(len(rdds)):
                 # reset them to sc.serializer
                 rdds[i] = rdds[i]._reserialize()
 
-    def queueStream(self, rdds, oneAtATime=True, default=None):
+    def queueStream(
+        self,
+        rdds: List[RDD[T]],
+        oneAtATime: bool = True,
+        default: Optional[RDD[T]] = None,
+    ) -> "DStream[T]":
         """
         Create an input stream from a queue of RDDs or list. In each batch,
         it will process either one or all of the RDDs returned by the queue.
@@ -339,42 +369,48 @@ class StreamingContext:
         Changes to the queue after the stream is created will not be recognized.
""" if default and not isinstance(default, RDD): - default = self._sc.parallelize(default) + default = self._sc.parallelize(default) # type: ignore[arg-type] if not rdds and default: - rdds = [rdds] + rdds = [rdds] # type: ignore[list-item] if rdds and not isinstance(rdds[0], RDD): - rdds = [self._sc.parallelize(input) for input in rdds] + rdds = [self._sc.parallelize(input) for input in rdds] # type: ignore[arg-type] self._check_serializers(rdds) + assert self._jvm is not None queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds]) if default: default = default._reserialize(rdds[0]._jrdd_deserializer) + assert default is not None jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) else: jdstream = self._jssc.queueStream(queue, oneAtATime) return DStream(jdstream, self, rdds[0]._jrdd_deserializer) - def transform(self, dstreams, transformFunc): + def transform( + self, dstreams: List["DStream[Any]"], transformFunc: Callable[..., RDD[T]] + ) -> "DStream[T]": """ Create a new DStream in which each RDD is generated by applying a function on RDDs of the DStreams. The order of the JavaRDDs in the transform function parameter will be the same as the order of corresponding DStreams in the list. """ - jdstreams = [d._jdstream for d in dstreams] + jdstreams = [d._jdstream for d in dstreams] # type: ignore[attr-defined] # change the final serializer to sc.serializer func = TransformFunction( self._sc, lambda t, *rdds: transformFunc(rdds), - *[d._jrdd_deserializer for d in dstreams], + *[d._jrdd_deserializer for d in dstreams], # type: ignore[attr-defined] ) + + assert self._jvm is not None jfunc = self._jvm.TransformFunction(func) jdstream = self._jssc.transform(jdstreams, jfunc) return DStream(jdstream, self, self._sc.serializer) - def union(self, *dstreams): + def union(self, *dstreams: "DStream[T]") -> "DStream[T]": """ Create a unified DStream from multiple DStreams of the same type and same slide duration. 
@@ -383,30 +419,43 @@ class StreamingContext:
             raise ValueError("should have at least one DStream to union")
         if len(dstreams) == 1:
             return dstreams[0]
-        if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
+        if len(set(s._jrdd_deserializer for s in dstreams)) > 1:  # type: ignore[attr-defined]
             raise ValueError("All DStreams should have same serializer")
-        if len(set(s._slideDuration for s in dstreams)) > 1:
+        if len(set(s._slideDuration for s in dstreams)) > 1:  # type: ignore[attr-defined]
             raise ValueError("All DStreams should have same slide duration")
+
+        assert SparkContext._jvm is not None
         jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
         jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
         gw = SparkContext._gateway
-        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
+        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):  # type: ignore[attr-defined]
             cls = jdstream_cls
-        elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):
+        elif is_instance_of(
+            gw, dstreams[0]._jdstream, jpair_dstream_cls  # type: ignore[attr-defined]
+        ):
             cls = jpair_dstream_cls
         else:
-            cls_name = dstreams[0]._jdstream.getClass().getCanonicalName()
+            cls_name = (
+                dstreams[0]._jdstream.getClass().getCanonicalName()  # type: ignore[attr-defined]
+            )
             raise TypeError("Unsupported Java DStream class %s" % cls_name)
+
+        assert gw is not None
         jdstreams = gw.new_array(cls, len(dstreams))
         for i in range(0, len(dstreams)):
-            jdstreams[i] = dstreams[i]._jdstream
-        return DStream(self._jssc.union(jdstreams), self, dstreams[0]._jrdd_deserializer)
+            jdstreams[i] = dstreams[i]._jdstream  # type: ignore[attr-defined]
+        return DStream(
+            self._jssc.union(jdstreams),
+            self,
+            dstreams[0]._jrdd_deserializer,  # type: ignore[attr-defined]
+        )
 
-    def addStreamingListener(self, streamingListener):
+    def addStreamingListener(self, streamingListener: StreamingListener) -> None:
         """
         Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
         receiving system events related to streaming.
         """
+        assert self._jvm is not None
         self._jssc.addStreamingListener(
             self._jvm.JavaStreamingListenerWrapper(
                 self._jvm.PythonStreamingListenerWrapper(streamingListener)
diff --git a/python/pyspark/streaming/context.pyi b/python/pyspark/streaming/context.pyi
deleted file mode 100644
index 0d1b2aca739..00000000000
--- a/python/pyspark/streaming/context.pyi
+++ /dev/null
@@ -1,71 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import Any, Callable, List, Optional, TypeVar
-
-from py4j.java_gateway import JavaObject
-
-from pyspark.context import SparkContext
-from pyspark.rdd import RDD
-from pyspark.storagelevel import StorageLevel
-from pyspark.streaming.dstream import DStream
-from pyspark.streaming.listener import StreamingListener
-
-T = TypeVar("T")
-
-class StreamingContext:
-    def __init__(
-        self,
-        sparkContext: SparkContext,
-        batchDuration: int = ...,
-        jssc: Optional[JavaObject] = ...,
-    ) -> None: ...
-    @classmethod
-    def getOrCreate(
-        cls, checkpointPath: str, setupFunc: Callable[[], StreamingContext]
-    ) -> StreamingContext: ...
-    @classmethod
-    def getActive(cls) -> StreamingContext: ...
-    @classmethod
-    def getActiveOrCreate(
-        cls, checkpointPath: str, setupFunc: Callable[[], StreamingContext]
-    ) -> StreamingContext: ...
-    @property
-    def sparkContext(self) -> SparkContext: ...
-    def start(self) -> None: ...
-    def awaitTermination(self, timeout: Optional[int] = ...) -> None: ...
-    def awaitTerminationOrTimeout(self, timeout: int) -> None: ...
-    def stop(self, stopSparkContext: bool = ..., stopGraceFully: bool = ...) -> None: ...
-    def remember(self, duration: int) -> None: ...
-    def checkpoint(self, directory: str) -> None: ...
-    def socketTextStream(
-        self, hostname: str, port: int, storageLevel: StorageLevel = ...
-    ) -> DStream[str]: ...
-    def textFileStream(self, directory: str) -> DStream[str]: ...
-    def binaryRecordsStream(self, directory: str, recordLength: int) -> DStream[bytes]: ...
-    def queueStream(
-        self,
-        rdds: List[RDD[T]],
-        oneAtATime: bool = ...,
-        default: Optional[RDD[T]] = ...,
-    ) -> DStream[T]: ...
-    def transform(
-        self, dstreams: List[DStream[Any]], transformFunc: Callable[..., RDD[T]]
-    ) -> DStream[T]: ...
-    def union(self, *dstreams: DStream[T]) -> DStream[T]: ...
-    def addStreamingListener(self, streamingListener: StreamingListener) -> None: ...
diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py
index 26d66c394ab..150fb79f572 100644
--- a/python/pyspark/streaming/kinesis.py
+++ b/python/pyspark/streaming/kinesis.py
@@ -153,10 +153,11 @@ class KinesisUtils:
         The given AWS credentials will get saved in DStream checkpoints if checkpointing
         is enabled. Make sure that your checkpoint directory is secure.
         """
-        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)  # type: ignore[attr-defined]
-        jduration = ssc._jduration(checkpointInterval)  # type: ignore[attr-defined]
+        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
+        jduration = ssc._jduration(checkpointInterval)
 
-        jvm = ssc._jvm  # type: ignore[attr-defined]
+        jvm = ssc._jvm
+        assert jvm is not None
 
         try:
             helper = jvm.org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper()
@@ -170,7 +171,7 @@ class KinesisUtils:
             )
             raise
         jstream = helper.createStream(
-            ssc._jssc,  # type: ignore[attr-defined]
+            ssc._jssc,
             kinesisAppName,
             streamName,
             endpointUrl,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
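
[Editor's note] For readers unfamiliar with the migration this commit performs (moving annotations from a `.pyi` stub into the `.py` source), here is a minimal illustrative sketch. It is not taken from the PR; the `Timer` class and `duration_ms` method are hypothetical names chosen only to show the before/after pattern and the `assert ... is not None` narrowing idiom the patch relies on.

from typing import Optional

# Before: the hints lived in a separate stub file (e.g. timer.pyi) such as
#
#     class Timer:
#         def duration_ms(self, seconds: Optional[int] = ...) -> int: ...
#
# so a type checker saw the signature but never checked the function body.

# After: the same hints are inlined in the .py file, so the checker also
# verifies the implementation itself.
class Timer:
    def duration_ms(self, seconds: Optional[int] = None) -> int:
        # The runtime assert narrows Optional[int] to int for the checker,
        # mirroring the `assert ... is not None` guards added in the patch.
        assert seconds is not None
        return int(seconds * 1000)


if __name__ == "__main__":
    print(Timer().duration_ms(2))  # prints 2000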