This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c0c1f35cd92 [SPARK-37014][PYTHON] Inline type hints for python/pyspark/streaming/context.py
c0c1f35cd92 is described below

commit c0c1f35cd9279bc1a7a50119be72a297162a9b55
Author: dch nguyen <dchvn....@gmail.com>
AuthorDate: Thu Apr 14 02:03:24 2022 +0200

    [SPARK-37014][PYTHON] Inline type hints for python/pyspark/streaming/context.py
    
    ### What changes were proposed in this pull request?
    Inline type hints for python/pyspark/streaming/context.py from the stub file python/pyspark/streaming/context.pyi.
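
    For illustration, here is a simplified excerpt (based on this diff) of a signature that moves from the stub into the module itself:

    ```python
    # Before: declared only in python/pyspark/streaming/context.pyi
    #     def remember(self, duration: int) -> None: ...

    # After: annotated directly in python/pyspark/streaming/context.py
    def remember(self, duration: int) -> None:
        """
        Set each DStream in this context to remember RDDs it generated
        in the last given duration (docstring trimmed here).
        """
        self._jssc.remember(self._jduration(duration))
    ```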
    
    ### Why are the changes needed?
    Currently, there is a type hint stub file, python/pyspark/streaming/context.pyi, to show the expected types for functions, but we can also take advantage of static type checking within the functions by inlining the type hints.
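
    As a rough sketch (not part of this patch; names are illustrative), this is the kind of mistake mypy can only catch once the hints live in the .py file, because a .pyi stub shadows the module and leaves function bodies unchecked:

    ```python
    from typing import Optional

    def batch_millis(batch_duration: Optional[int] = None) -> int:
        # With a stub-only setup, mypy never analyzes this body.
        # With inline hints, it reports the last line:
        # error: Incompatible return value type (got "str", expected "int")
        if batch_duration is None:
            return 1000
        return str(batch_duration * 1000)
    ```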
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing test.
    
    Closes #34293 from dchvn/SPARK-37014.
    
    Authored-by: dch nguyen <dchvn....@gmail.com>
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
---
 python/pyspark/streaming/context.py  | 123 ++++++++++++++++++++++++-----------
 python/pyspark/streaming/context.pyi |  71 --------------------
 python/pyspark/streaming/kinesis.py  |   9 +--
 3 files changed, 91 insertions(+), 112 deletions(-)

diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index cc9875d6575..52e5efed063 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -14,18 +14,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import Any, Callable, List, Optional, TypeVar
 
-from py4j.java_gateway import java_import, is_instance_of
+from py4j.java_gateway import java_import, is_instance_of, JavaObject
 
 from pyspark import RDD, SparkConf
 from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer
 from pyspark.context import SparkContext
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.dstream import DStream
+from pyspark.streaming.listener import StreamingListener
 from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer
 
 __all__ = ["StreamingContext"]
 
+T = TypeVar("T")
+
 
 class StreamingContext:
     """
@@ -51,27 +55,35 @@ class StreamingContext:
     # Reference to a currently active StreamingContext
     _activeContext = None
 
-    def __init__(self, sparkContext, batchDuration=None, jssc=None):
-
+    def __init__(
+        self,
+        sparkContext: SparkContext,
+        batchDuration: Optional[int] = None,
+        jssc: Optional[JavaObject] = None,
+    ):
         self._sc = sparkContext
         self._jvm = self._sc._jvm
         self._jssc = jssc or self._initialize_context(self._sc, batchDuration)
 
-    def _initialize_context(self, sc, duration):
+    def _initialize_context(self, sc: SparkContext, duration: Optional[int]) -> JavaObject:
         self._ensure_initialized()
+        assert self._jvm is not None and duration is not None
         return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
 
-    def _jduration(self, seconds):
+    def _jduration(self, seconds: int) -> JavaObject:
         """
         Create Duration object given number of seconds
         """
+        assert self._jvm is not None
         return self._jvm.Duration(int(seconds * 1000))
 
     @classmethod
-    def _ensure_initialized(cls):
+    def _ensure_initialized(cls) -> None:
         SparkContext._ensure_initialized()
         gw = SparkContext._gateway
 
+        assert gw is not None
+
         java_import(gw.jvm, "org.apache.spark.streaming.*")
         java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
         java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
@@ -83,11 +95,15 @@ class StreamingContext:
         # register serializer for TransformFunction
         # it happens before creating SparkContext when loading from checkpointing
         cls._transformerSerializer = TransformFunctionSerializer(
-            SparkContext._active_spark_context, CloudPickleSerializer(), gw
+            SparkContext._active_spark_context,
+            CloudPickleSerializer(),
+            gw,
         )
 
     @classmethod
-    def getOrCreate(cls, checkpointPath, setupFunc):
+    def getOrCreate(
+        cls, checkpointPath: str, setupFunc: Callable[[], "StreamingContext"]
+    ) -> "StreamingContext":
         """
         Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
         If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
@@ -104,6 +120,8 @@ class StreamingContext:
         cls._ensure_initialized()
         gw = SparkContext._gateway
 
+        assert gw is not None
+
         # Check whether valid checkpoint information exists in the given path
         ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath)
         if ssc_option.isEmpty():
@@ -121,12 +139,15 @@ class StreamingContext:
 
         sc = SparkContext._active_spark_context
 
+        assert sc is not None
+
         # update ctx in serializer
+        assert cls._transformerSerializer is not None
         cls._transformerSerializer.ctx = sc
         return StreamingContext(sc, None, jssc)
 
     @classmethod
-    def getActive(cls):
+    def getActive(cls) -> Optional["StreamingContext"]:
         """
         Return either the currently active StreamingContext (i.e., if there is a context started
         but not stopped) or None.
@@ -149,7 +170,9 @@ class StreamingContext:
         return cls._activeContext
 
     @classmethod
-    def getActiveOrCreate(cls, checkpointPath, setupFunc):
+    def getActiveOrCreate(
+        cls, checkpointPath: str, setupFunc: Callable[[], "StreamingContext"]
+    ) -> "StreamingContext":
         """
         Either return the active StreamingContext (i.e. currently started but not stopped),
         or recreate a StreamingContext from checkpoint data or create a new StreamingContext
@@ -178,20 +201,20 @@ class StreamingContext:
             return setupFunc()
 
     @property
-    def sparkContext(self):
+    def sparkContext(self) -> SparkContext:
         """
         Return SparkContext which is associated with this StreamingContext.
         """
         return self._sc
 
-    def start(self):
+    def start(self) -> None:
         """
         Start the execution of the streams.
         """
         self._jssc.start()
         StreamingContext._activeContext = self
 
-    def awaitTermination(self, timeout=None):
+    def awaitTermination(self, timeout: Optional[int] = None) -> None:
         """
         Wait for the execution to stop.
 
@@ -205,7 +228,7 @@ class StreamingContext:
         else:
             self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))
 
-    def awaitTerminationOrTimeout(self, timeout):
+    def awaitTerminationOrTimeout(self, timeout: int) -> None:
         """
         Wait for the execution to stop. Return `true` if it's stopped; or
         throw the reported error during the execution; or `false` if the
@@ -218,7 +241,7 @@ class StreamingContext:
         """
         return self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))
 
-    def stop(self, stopSparkContext=True, stopGraceFully=False):
+    def stop(self, stopSparkContext: bool = True, stopGraceFully: bool = False) -> None:
         """
         Stop the execution of the streams, with option of ensuring all
         received data has been processed.
@@ -236,7 +259,7 @@ class StreamingContext:
         if stopSparkContext:
             self._sc.stop()
 
-    def remember(self, duration):
+    def remember(self, duration: int) -> None:
         """
         Set each DStreams in this context to remember RDDs it generated
         in the last given duration. DStreams remember RDDs only for a
@@ -252,7 +275,7 @@ class StreamingContext:
         """
         self._jssc.remember(self._jduration(duration))
 
-    def checkpoint(self, directory):
+    def checkpoint(self, directory: str) -> None:
         """
         Sets the context to periodically checkpoint the DStream operations for master
         fault-tolerance. The graph will be checkpointed every batch interval.
@@ -264,7 +287,9 @@ class StreamingContext:
         """
         self._jssc.checkpoint(directory)
 
-    def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_2):
+    def socketTextStream(
+        self, hostname: str, port: int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_2
+    ) -> "DStream[str]":
         """
         Create an input from TCP source hostname:port. Data is received using
         a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited
@@ -284,7 +309,7 @@ class StreamingContext:
             self._jssc.socketTextStream(hostname, port, jlevel), self, UTF8Deserializer()
         )
 
-    def textFileStream(self, directory):
+    def textFileStream(self, directory: str) -> "DStream[str]":
         """
         Create an input stream that monitors a Hadoop-compatible file system
         for new files and reads them as text files. Files must be written to the
@@ -294,7 +319,7 @@ class StreamingContext:
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
 
-    def binaryRecordsStream(self, directory, recordLength):
+    def binaryRecordsStream(self, directory: str, recordLength: int) -> "DStream[bytes]":
         """
         Create an input stream that monitors a Hadoop-compatible file system
         for new files and reads them as flat binary files with records of
@@ -313,14 +338,19 @@ class StreamingContext:
             self._jssc.binaryRecordsStream(directory, recordLength), self, NoOpSerializer()
         )
 
-    def _check_serializers(self, rdds):
+    def _check_serializers(self, rdds: List[RDD[T]]) -> None:
         # make sure they have same serializer
         if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
             for i in range(len(rdds)):
                 # reset them to sc.serializer
                 rdds[i] = rdds[i]._reserialize()
 
-    def queueStream(self, rdds, oneAtATime=True, default=None):
+    def queueStream(
+        self,
+        rdds: List[RDD[T]],
+        oneAtATime: bool = True,
+        default: Optional[RDD[T]] = None,
+    ) -> "DStream[T]":
         """
         Create an input stream from a queue of RDDs or list. In each batch,
         it will process either one or all of the RDDs returned by the queue.
@@ -339,42 +369,48 @@ class StreamingContext:
         Changes to the queue after the stream is created will not be recognized.
         """
         if default and not isinstance(default, RDD):
-            default = self._sc.parallelize(default)
+            default = self._sc.parallelize(default)  # type: ignore[arg-type]
 
         if not rdds and default:
-            rdds = [rdds]
+            rdds = [rdds]  # type: ignore[list-item]
 
         if rdds and not isinstance(rdds[0], RDD):
-            rdds = [self._sc.parallelize(input) for input in rdds]
+            rdds = [self._sc.parallelize(input) for input in rdds]  # type: ignore[arg-type]
         self._check_serializers(rdds)
 
+        assert self._jvm is not None
         queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds])
         if default:
             default = default._reserialize(rdds[0]._jrdd_deserializer)
+            assert default is not None
             jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
         else:
             jdstream = self._jssc.queueStream(queue, oneAtATime)
         return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
 
-    def transform(self, dstreams, transformFunc):
+    def transform(
+        self, dstreams: List["DStream[Any]"], transformFunc: Callable[..., RDD[T]]
+    ) -> "DStream[T]":
         """
         Create a new DStream in which each RDD is generated by applying
         a function on RDDs of the DStreams. The order of the JavaRDDs in
         the transform function parameter will be the same as the order
         of corresponding DStreams in the list.
         """
-        jdstreams = [d._jdstream for d in dstreams]
+        jdstreams = [d._jdstream for d in dstreams]  # type: ignore[attr-defined]
         # change the final serializer to sc.serializer
         func = TransformFunction(
             self._sc,
             lambda t, *rdds: transformFunc(rdds),
-            *[d._jrdd_deserializer for d in dstreams],
+            *[d._jrdd_deserializer for d in dstreams],  # type: ignore[attr-defined]
         )
+
+        assert self._jvm is not None
         jfunc = self._jvm.TransformFunction(func)
         jdstream = self._jssc.transform(jdstreams, jfunc)
         return DStream(jdstream, self, self._sc.serializer)
 
-    def union(self, *dstreams):
+    def union(self, *dstreams: "DStream[T]") -> "DStream[T]":
         """
         Create a unified DStream from multiple DStreams of the same
         type and same slide duration.
@@ -383,30 +419,43 @@ class StreamingContext:
             raise ValueError("should have at least one DStream to union")
         if len(dstreams) == 1:
             return dstreams[0]
-        if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
+        if len(set(s._jrdd_deserializer for s in dstreams)) > 1:  # type: ignore[attr-defined]
             raise ValueError("All DStreams should have same serializer")
-        if len(set(s._slideDuration for s in dstreams)) > 1:
+        if len(set(s._slideDuration for s in dstreams)) > 1:  # type: ignore[attr-defined]
             raise ValueError("All DStreams should have same slide duration")
+
+        assert SparkContext._jvm is not None
         jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
         jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
         gw = SparkContext._gateway
-        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
+        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):  # type: ignore[attr-defined]
             cls = jdstream_cls
-        elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):
+        elif is_instance_of(
+            gw, dstreams[0]._jdstream, jpair_dstream_cls  # type: ignore[attr-defined]
+        ):
             cls = jpair_dstream_cls
         else:
-            cls_name = dstreams[0]._jdstream.getClass().getCanonicalName()
+            cls_name = (
+                dstreams[0]._jdstream.getClass().getCanonicalName()  # type: ignore[attr-defined]
+            )
             raise TypeError("Unsupported Java DStream class %s" % cls_name)
+
+        assert gw is not None
         jdstreams = gw.new_array(cls, len(dstreams))
         for i in range(0, len(dstreams)):
-            jdstreams[i] = dstreams[i]._jdstream
-        return DStream(self._jssc.union(jdstreams), self, dstreams[0]._jrdd_deserializer)
+            jdstreams[i] = dstreams[i]._jdstream  # type: ignore[attr-defined]
+        return DStream(
+            self._jssc.union(jdstreams),
+            self,
+            dstreams[0]._jrdd_deserializer,  # type: ignore[attr-defined]
+        )
 
-    def addStreamingListener(self, streamingListener):
+    def addStreamingListener(self, streamingListener: StreamingListener) -> None:
         """
         Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
         receiving system events related to streaming.
         """
+        assert self._jvm is not None
         self._jssc.addStreamingListener(
             self._jvm.JavaStreamingListenerWrapper(
                 self._jvm.PythonStreamingListenerWrapper(streamingListener)
diff --git a/python/pyspark/streaming/context.pyi b/python/pyspark/streaming/context.pyi
deleted file mode 100644
index 0d1b2aca739..00000000000
--- a/python/pyspark/streaming/context.pyi
+++ /dev/null
@@ -1,71 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import Any, Callable, List, Optional, TypeVar
-
-from py4j.java_gateway import JavaObject
-
-from pyspark.context import SparkContext
-from pyspark.rdd import RDD
-from pyspark.storagelevel import StorageLevel
-from pyspark.streaming.dstream import DStream
-from pyspark.streaming.listener import StreamingListener
-
-T = TypeVar("T")
-
-class StreamingContext:
-    def __init__(
-        self,
-        sparkContext: SparkContext,
-        batchDuration: int = ...,
-        jssc: Optional[JavaObject] = ...,
-    ) -> None: ...
-    @classmethod
-    def getOrCreate(
-        cls, checkpointPath: str, setupFunc: Callable[[], StreamingContext]
-    ) -> StreamingContext: ...
-    @classmethod
-    def getActive(cls) -> StreamingContext: ...
-    @classmethod
-    def getActiveOrCreate(
-        cls, checkpointPath: str, setupFunc: Callable[[], StreamingContext]
-    ) -> StreamingContext: ...
-    @property
-    def sparkContext(self) -> SparkContext: ...
-    def start(self) -> None: ...
-    def awaitTermination(self, timeout: Optional[int] = ...) -> None: ...
-    def awaitTerminationOrTimeout(self, timeout: int) -> None: ...
-    def stop(self, stopSparkContext: bool = ..., stopGraceFully: bool = ...) -> None: ...
-    def remember(self, duration: int) -> None: ...
-    def checkpoint(self, directory: str) -> None: ...
-    def socketTextStream(
-        self, hostname: str, port: int, storageLevel: StorageLevel = ...
-    ) -> DStream[str]: ...
-    def textFileStream(self, directory: str) -> DStream[str]: ...
-    def binaryRecordsStream(self, directory: str, recordLength: int) -> DStream[bytes]: ...
-    def queueStream(
-        self,
-        rdds: List[RDD[T]],
-        oneAtATime: bool = ...,
-        default: Optional[RDD[T]] = ...,
-    ) -> DStream[T]: ...
-    def transform(
-        self, dstreams: List[DStream[Any]], transformFunc: Callable[..., RDD[T]]
-    ) -> DStream[T]: ...
-    def union(self, *dstreams: DStream[T]) -> DStream[T]: ...
-    def addStreamingListener(self, streamingListener: StreamingListener) -> None: ...
diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py
index 26d66c394ab..150fb79f572 100644
--- a/python/pyspark/streaming/kinesis.py
+++ b/python/pyspark/streaming/kinesis.py
@@ -153,10 +153,11 @@ class KinesisUtils:
         The given AWS credentials will get saved in DStream checkpoints if checkpointing
         is enabled. Make sure that your checkpoint directory is secure.
         """
-        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)  # type: ignore[attr-defined]
-        jduration = ssc._jduration(checkpointInterval)  # type: ignore[attr-defined]
+        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
+        jduration = ssc._jduration(checkpointInterval)
 
-        jvm = ssc._jvm  # type: ignore[attr-defined]
+        jvm = ssc._jvm
+        assert jvm is not None
 
         try:
             helper = jvm.org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper()
@@ -170,7 +171,7 @@ class KinesisUtils:
                 )
             raise
         jstream = helper.createStream(
-            ssc._jssc,  # type: ignore[attr-defined]
+            ssc._jssc,
             kinesisAppName,
             streamName,
             endpointUrl,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
