This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e3b4d1e  [SPARK-36991][PYTHON] Inline type hints for spark/python/pyspark/sql/streaming.py
e3b4d1e is described below

commit e3b4d1eef37704f8ebb5bc669ad041166cc674f0
Author: Xinrong Meng <xinrong.m...@databricks.com>
AuthorDate: Fri Oct 15 12:05:08 2021 -0700

    [SPARK-36991][PYTHON] Inline type hints for spark/python/pyspark/sql/streaming.py
    
    ### What changes were proposed in this pull request?
    Inline type hints for spark/python/pyspark/sql/streaming.py
    
    ### Why are the changes needed?
    Current stub files cannot support type checking within function bodies, so we want to inline type hints for spark/python/pyspark/sql/streaming.py.
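    
    As an illustrative sketch (drawn from the diff below, not from the original PR description): with a stub file, a signature such as
    
        def awaitTermination(self, timeout: Optional[int] = ...) -> Optional[bool]: ...
    
    lives in streaming.pyi while the implementation in streaming.py stays unannotated, so mypy does not check its body; after inlining, the annotation sits on the implementation itself and the body is type checked too:
    
        def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
            if timeout is not None:
                return self._jsq.awaitTermination(int(timeout * 1000))
            else:
                return self._jsq.awaitTermination()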
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing test.
    
    Closes #34277 from xinrong-databricks/inline_streaming.
    
    Authored-by: Xinrong Meng <xinrong.m...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/sql/streaming.py  | 355 +++++++++++++++++++++++++++------------
 python/pyspark/sql/streaming.pyi | 197 ----------------------
 2 files changed, 250 insertions(+), 302 deletions(-)

diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 08c8934..24cd2db 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -17,15 +17,22 @@
 
 import sys
 import json
+from collections.abc import Iterator
+from typing import cast, overload, Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
 
-from py4j.java_gateway import java_import
+from py4j.java_gateway import java_import, JavaObject
 
-from pyspark import since, keyword_only
+from pyspark import since
 from pyspark.sql.column import _to_seq
-from pyspark.sql.readwriter import OptionUtils, to_str
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.readwriter import OptionUtils, to_str  # type: ignore[attr-defined]
+from pyspark.sql.types import Row, StructType, StructField, StringType
 from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException
 
+if TYPE_CHECKING:
+    from pyspark.sql import SQLContext
+    from pyspark.sql._typing import SupportsProcess, OptionalPrimitiveType
+    from pyspark.sql.dataframe import DataFrame
+
 __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", 
"DataStreamWriter"]
 
 
@@ -41,12 +48,12 @@ class StreamingQuery(object):
     This API is evolving.
     """
 
-    def __init__(self, jsq):
+    def __init__(self, jsq: JavaObject) -> None:
         self._jsq = jsq
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.0)
-    def id(self):
+    def id(self) -> str:
         """Returns the unique id of this query that persists across restarts 
from checkpoint data.
         That is, this id is generated when a query is started for the first 
time, and
         will be the same every time it is restarted from checkpoint data.
@@ -55,17 +62,17 @@ class StreamingQuery(object):
         """
         return self._jsq.id().toString()
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.1)
-    def runId(self):
+    def runId(self) -> str:
         """Returns the unique id of this query that does not persist across 
restarts. That is, every
         query that is started (or restarted from checkpoint) will have a 
different runId.
         """
         return self._jsq.runId().toString()
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.0)
-    def name(self):
+    def name(self) -> str:
         """Returns the user-specified name of the query, or null if not 
specified.
         This name can be specified in the 
`org.apache.spark.sql.streaming.DataStreamWriter`
         as `dataframe.writeStream.queryName("query").start()`.
@@ -73,15 +80,15 @@ class StreamingQuery(object):
         """
         return self._jsq.name()
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.0)
-    def isActive(self):
+    def isActive(self) -> bool:
         """Whether this streaming query is currently active or not.
         """
         return self._jsq.isActive()
 
     @since(2.0)
-    def awaitTermination(self, timeout=None):
+    def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
         """Waits for the termination of `this` query, either by 
:func:`query.stop()` or by an
         exception. If the query has terminated with an exception, then the 
exception will be thrown.
         If `timeout` is set, it returns whether the query has terminated or 
not within the
@@ -100,17 +107,17 @@ class StreamingQuery(object):
         else:
             return self._jsq.awaitTermination()
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.1)
-    def status(self):
+    def status(self) -> Dict[str, Any]:
         """
         Returns the current status of the query.
         """
         return json.loads(self._jsq.status().json())
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.1)
-    def recentProgress(self):
+    def recentProgress(self) -> List[Dict[str, Any]]:
         """Returns an array of the most recent [[StreamingQueryProgress]] 
updates for this query.
         The number of progress updates retained for each stream is configured 
by Spark session
         configuration `spark.sql.streaming.numRecentProgressUpdates`.
@@ -118,7 +125,7 @@ class StreamingQuery(object):
         return [json.loads(p.json()) for p in self._jsq.recentProgress()]
 
     @property
-    def lastProgress(self):
+    def lastProgress(self) -> Optional[Dict[str, Any]]:
         """
         Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or
         None if there were no progress updates
@@ -135,7 +142,7 @@ class StreamingQuery(object):
         else:
             return None
 
-    def processAllAvailable(self):
+    def processAllAvailable(self) -> None:
         """Blocks until all available data in the source has been processed 
and committed to the
         sink. This method is intended for testing.
 
@@ -151,12 +158,12 @@ class StreamingQuery(object):
         return self._jsq.processAllAvailable()
 
     @since(2.0)
-    def stop(self):
+    def stop(self) -> None:
         """Stop this streaming query.
         """
         self._jsq.stop()
 
-    def explain(self, extended=False):
+    def explain(self, extended: bool = False) -> None:
         """Prints the (logical and physical) plans to the console for 
debugging purpose.
 
         .. versionadded:: 2.1.0
@@ -188,7 +195,7 @@ class StreamingQuery(object):
         # We should print it in the Python process.
         print(self._jsq.explainInternal(extended))
 
-    def exception(self):
+    def exception(self) -> Optional[StreamingQueryException]:
         """
         .. versionadded:: 2.1.0
 
@@ -216,11 +223,11 @@ class StreamingQueryManager(object):
     This API is evolving.
     """
 
-    def __init__(self, jsqm):
+    def __init__(self, jsqm: JavaObject) -> None:
         self._jsqm = jsqm
 
     @property
-    def active(self):
+    def active(self) -> List[StreamingQuery]:
         """Returns a list of active queries associated with this SQLContext
 
         .. versionadded:: 2.0.0
@@ -236,7 +243,7 @@ class StreamingQueryManager(object):
         """
         return [StreamingQuery(jsq) for jsq in self._jsqm.active()]
 
-    def get(self, id):
+    def get(self, id: str) -> StreamingQuery:
         """Returns an active query from this SQLContext or throws exception if 
an active query
         with this name doesn't exist.
 
@@ -258,7 +265,7 @@ class StreamingQueryManager(object):
         return StreamingQuery(self._jsqm.get(id))
 
     @since(2.0)
-    def awaitAnyTermination(self, timeout=None):
+    def awaitAnyTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
         """Wait until any of the queries on the associated SQLContext has 
terminated since the
         creation of the context, or since :func:`resetTerminated()` was 
called. If any query was
         terminated with an exception, then the exception will be thrown.
@@ -285,7 +292,7 @@ class StreamingQueryManager(object):
         else:
             return self._jsqm.awaitAnyTermination()
 
-    def resetTerminated(self):
+    def resetTerminated(self) -> None:
         """Forget about past terminated queries so that 
:func:`awaitAnyTermination()` can be used
         again to wait for new terminations.
 
@@ -311,15 +318,15 @@ class DataStreamReader(OptionUtils):
     This API is evolving.
     """
 
-    def __init__(self, spark):
-        self._jreader = spark._ssql_ctx.readStream()
+    def __init__(self, spark: "SQLContext") -> None:
+        self._jreader = spark._ssql_ctx.readStream()  # type: ignore[attr-defined]
         self._spark = spark
 
-    def _df(self, jdf):
+    def _df(self, jdf: JavaObject) -> "DataFrame":
         from pyspark.sql.dataframe import DataFrame
         return DataFrame(jdf, self._spark)
 
-    def format(self, source):
+    def format(self, source: str) -> "DataStreamReader":
         """Specifies the input data source format.
 
         .. versionadded:: 2.0.0
@@ -340,7 +347,7 @@ class DataStreamReader(OptionUtils):
         self._jreader = self._jreader.format(source)
         return self
 
-    def schema(self, schema):
+    def schema(self, schema: Union[StructType, str]) -> "DataStreamReader":
         """Specifies the input schema.
 
         Some data sources (e.g. JSON) can infer the input schema automatically from data.
@@ -375,7 +382,7 @@ class DataStreamReader(OptionUtils):
             raise TypeError("schema should be StructType or string")
         return self
 
-    def option(self, key, value):
+    def option(self, key: str, value: "OptionalPrimitiveType") -> 
"DataStreamReader":
         """Adds an input option for the underlying data source.
 
         .. versionadded:: 2.0.0
@@ -391,7 +398,7 @@ class DataStreamReader(OptionUtils):
         self._jreader = self._jreader.option(key, to_str(value))
         return self
 
-    def options(self, **options):
+    def options(self, **options: "OptionalPrimitiveType") -> 
"DataStreamReader":
         """Adds input options for the underlying data source.
 
         .. versionadded:: 2.0.0
@@ -408,7 +415,13 @@ class DataStreamReader(OptionUtils):
             self._jreader = self._jreader.option(k, to_str(options[k]))
         return self
 
-    def load(self, path=None, format=None, schema=None, **options):
+    def load(
+        self,
+        path: Optional[str] = None,
+        format: Optional[str] = None,
+        schema: Optional[Union[StructType, str]] = None,
+        **options: "OptionalPrimitiveType"
+    ) -> "DataFrame":
         """Loads a data stream from a data source and returns it as a
         :class:`DataFrame <pyspark.sql.DataFrame>`.
 
@@ -453,13 +466,31 @@ class DataStreamReader(OptionUtils):
         else:
             return self._df(self._jreader.load())
 
-    def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
-             allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
-             allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
-             mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
-             multiLine=None,  allowUnquotedControlChars=None, lineSep=None, locale=None,
-             dropFieldIfAllNull=None, encoding=None, pathGlobFilter=None,
-             recursiveFileLookup=None, allowNonNumericNumbers=None):
+    def json(
+        self,
+        path: str,
+        schema: Optional[Union[StructType, str]] = None,
+        primitivesAsString: Optional[Union[bool, str]] = None,
+        prefersDecimal: Optional[Union[bool, str]] = None,
+        allowComments: Optional[Union[bool, str]] = None,
+        allowUnquotedFieldNames: Optional[Union[bool, str]] = None,
+        allowSingleQuotes: Optional[Union[bool, str]] = None,
+        allowNumericLeadingZero: Optional[Union[bool, str]] = None,
+        allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = None,
+        mode: Optional[str] = None,
+        columnNameOfCorruptRecord: Optional[str] = None,
+        dateFormat: Optional[str] = None,
+        timestampFormat: Optional[str] = None,
+        multiLine: Optional[Union[bool, str]] = None,
+        allowUnquotedControlChars: Optional[Union[bool, str]] = None,
+        lineSep: Optional[str] = None,
+        locale: Optional[str] = None,
+        dropFieldIfAllNull: Optional[Union[bool, str]] = None,
+        encoding: Optional[str] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+        allowNonNumericNumbers: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
         """
         Loads a JSON file stream and returns the results as a :class:`DataFrame`.
 
@@ -501,7 +532,7 @@ class DataStreamReader(OptionUtils):
         >>> json_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(
+        self._set_opts(  # type: ignore[attr-defined]
             schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
             allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -517,7 +548,13 @@ class DataStreamReader(OptionUtils):
         else:
             raise TypeError("path can be only a single string")
 
-    def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None):
+    def orc(
+        self,
+        path: str,
+        mergeSchema: Optional[bool] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
         """Loads a ORC file stream, returning the result as a 
:class:`DataFrame`.
 
         .. versionadded:: 2.3.0
@@ -539,15 +576,25 @@ class DataStreamReader(OptionUtils):
         >>> orc_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter,
-                       recursiveFileLookup=recursiveFileLookup)
+        self._set_opts(   # type: ignore[attr-defined]
+            mergeSchema=mergeSchema,
+            pathGlobFilter=pathGlobFilter,
+            recursiveFileLookup=recursiveFileLookup
+        )
         if isinstance(path, str):
             return self._df(self._jreader.orc(path))
         else:
             raise TypeError("path can be only a single string")
 
-    def parquet(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None,
-                datetimeRebaseMode=None, int96RebaseMode=None):
+    def parquet(
+        self,
+        path: str,
+        mergeSchema: Optional[bool] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+        datetimeRebaseMode: Optional[Union[bool, str]] = None,
+        int96RebaseMode: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
         """
         Loads a Parquet file stream, returning the result as a :class:`DataFrame`.
 
@@ -575,16 +622,26 @@ class DataStreamReader(OptionUtils):
         >>> parquet_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter,
-                       recursiveFileLookup=recursiveFileLookup,
-                       datetimeRebaseMode=datetimeRebaseMode, int96RebaseMode=int96RebaseMode)
+        self._set_opts(  # type: ignore[attr-defined]
+            mergeSchema=mergeSchema,
+            pathGlobFilter=pathGlobFilter,
+            recursiveFileLookup=recursiveFileLookup,
+            datetimeRebaseMode=datetimeRebaseMode,
+            int96RebaseMode=int96RebaseMode
+        )
         if isinstance(path, str):
             return self._df(self._jreader.parquet(path))
         else:
             raise TypeError("path can be only a single string")
 
-    def text(self, path, wholetext=False, lineSep=None, pathGlobFilter=None,
-             recursiveFileLookup=None):
+    def text(
+        self,
+        path: str,
+        wholetext: bool = False,
+        lineSep: Optional[str] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
         """
         Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a
         string column named "value", and followed by partitioned columns if there
@@ -597,7 +654,7 @@ class DataStreamReader(OptionUtils):
 
         Parameters
         ----------
-        paths : str or list
+        path : str or list
             string, or list of strings, for input path(s).
 
         Other Parameters
@@ -621,7 +678,7 @@ class DataStreamReader(OptionUtils):
         >>> "value" in str(text_sdf.schema)
         True
         """
-        self._set_opts(
+        self._set_opts(  # type: ignore[attr-defined]
             wholetext=wholetext, lineSep=lineSep, pathGlobFilter=pathGlobFilter,
             recursiveFileLookup=recursiveFileLookup)
         if isinstance(path, str):
@@ -629,14 +686,40 @@ class DataStreamReader(OptionUtils):
         else:
             raise TypeError("path can be only a single string")
 
-    def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
-            comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
-            ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
-            negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
-            maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
-            columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-            enforceSchema=None, emptyValue=None, locale=None, lineSep=None,
-            pathGlobFilter=None, recursiveFileLookup=None, unescapedQuoteHandling=None):
+    def csv(
+        self,
+        path: str,
+        schema: Optional[Union[StructType, str]] = None,
+        sep: Optional[str] = None,
+        encoding: Optional[str] = None,
+        quote: Optional[str] = None,
+        escape: Optional[str] = None,
+        comment: Optional[str] = None,
+        header: Optional[Union[bool, str]] = None,
+        inferSchema: Optional[Union[bool, str]] = None,
+        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
+        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
+        nullValue: Optional[str] = None,
+        nanValue: Optional[str] = None,
+        positiveInf: Optional[str] = None,
+        negativeInf: Optional[str] = None,
+        dateFormat: Optional[str] = None,
+        timestampFormat: Optional[str] = None,
+        maxColumns: Optional[Union[int, str]] = None,
+        maxCharsPerColumn: Optional[Union[int, str]] = None,
+        maxMalformedLogPerPartition: Optional[Union[int, str]] = None,
+        mode: Optional[str] = None,
+        columnNameOfCorruptRecord: Optional[str] = None,
+        multiLine: Optional[Union[bool, str]] = None,
+        charToEscapeQuoteEscaping: Optional[Union[bool, str]] = None,
+        enforceSchema: Optional[Union[bool, str]] = None,
+        emptyValue: Optional[str] = None,
+        locale: Optional[str] = None,
+        lineSep: Optional[str] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+        unescapedQuoteHandling: Optional[str] = None,
+    ) -> "DataFrame":
         r"""Loads a CSV file stream and returns the result as a 
:class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -674,7 +757,7 @@ class DataStreamReader(OptionUtils):
         >>> csv_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(
+        self._set_opts(  # type: ignore[attr-defined]
             schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
@@ -692,7 +775,7 @@ class DataStreamReader(OptionUtils):
         else:
             raise TypeError("path can be only a single string")
 
-    def table(self, tableName):
+    def table(self, tableName: str) -> "DataFrame":
         """Define a Streaming DataFrame on a Table. The DataSource 
corresponding to the table should
         support streaming mode.
 
@@ -735,16 +818,16 @@ class DataStreamWriter(object):
     This API is evolving.
     """
 
-    def __init__(self, df):
+    def __init__(self, df: "DataFrame") -> None:
         self._df = df
         self._spark = df.sql_ctx
         self._jwrite = df._jdf.writeStream()
 
-    def _sq(self, jsq):
+    def _sq(self, jsq: JavaObject) -> StreamingQuery:
         from pyspark.sql.streaming import StreamingQuery
         return StreamingQuery(jsq)
 
-    def outputMode(self, outputMode):
+    def outputMode(self, outputMode: str) -> "DataStreamWriter":
         """Specifies how data of a streaming DataFrame/Dataset is written to a 
streaming sink.
 
         .. versionadded:: 2.0.0
@@ -772,7 +855,7 @@ class DataStreamWriter(object):
         self._jwrite = self._jwrite.outputMode(outputMode)
         return self
 
-    def format(self, source):
+    def format(self, source: str) -> "DataStreamWriter":
         """Specifies the underlying output data source.
 
         .. versionadded:: 2.0.0
@@ -793,7 +876,7 @@ class DataStreamWriter(object):
         self._jwrite = self._jwrite.format(source)
         return self
 
-    def option(self, key, value):
+    def option(self, key: str, value: "OptionalPrimitiveType") -> 
"DataStreamWriter":
         """Adds an output option for the underlying data source.
 
         .. versionadded:: 2.0.0
@@ -805,7 +888,7 @@ class DataStreamWriter(object):
         self._jwrite = self._jwrite.option(key, to_str(value))
         return self
 
-    def options(self, **options):
+    def options(self, **options: "OptionalPrimitiveType") -> 
"DataStreamWriter":
         """Adds output options for the underlying data source.
 
         .. versionadded:: 2.0.0
@@ -818,7 +901,15 @@ class DataStreamWriter(object):
             self._jwrite = self._jwrite.option(k, to_str(options[k]))
         return self
 
-    def partitionBy(self, *cols):
+    @overload
+    def partitionBy(self, *cols: str) -> "DataStreamWriter":
+        ...
+
+    @overload
+    def partitionBy(self, __cols: List[str]) -> "DataStreamWriter":
+        ...
+
+    def partitionBy(self, *cols: str) -> "DataStreamWriter":  # type: ignore[misc]
         """Partitions the output by the given columns on the file system.
 
         If specified, the output is laid out on the file system similar
@@ -837,10 +928,11 @@ class DataStreamWriter(object):
         """
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
-        self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
+        self._jwrite = self._jwrite.partitionBy(
+            _to_seq(self._spark._sc, cols))  # type: ignore[attr-defined]
         return self
 
-    def queryName(self, queryName):
+    def queryName(self, queryName: str) -> "DataStreamWriter":
         """Specifies the name of the :class:`StreamingQuery` that can be 
started with
         :func:`start`. This name must be unique among all the currently active 
queries
         in the associated SparkSession.
@@ -865,8 +957,25 @@ class DataStreamWriter(object):
         self._jwrite = self._jwrite.queryName(queryName)
         return self
 
-    @keyword_only
-    def trigger(self, *, processingTime=None, once=None, continuous=None):
+    @overload
+    def trigger(self, *, processingTime: str) -> "DataStreamWriter":
+        ...
+
+    @overload
+    def trigger(self, *, once: bool) -> "DataStreamWriter":
+        ...
+
+    @overload
+    def trigger(self, *, continuous: str) -> "DataStreamWriter":
+        ...
+
+    def trigger(
+        self,
+        *,
+        processingTime: Optional[str] = None,
+        once: Optional[bool] = None,
+        continuous: Optional[str] = None,
+    ) -> "DataStreamWriter":
         """Set the trigger for the stream query. If this is not set it will 
run the query as fast
         as possible, which is equivalent to setting the trigger to 
``processingTime='0 seconds'``.
 
@@ -912,26 +1021,41 @@ class DataStreamWriter(object):
                 raise ValueError('Value for processingTime must be a non empty string. Got: %s' %
                                  processingTime)
             interval = processingTime.strip()
-            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime(
-                interval)
+            jTrigger = (
+                self._spark._sc  # type: ignore[attr-defined]
+                ._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime(interval)
+            )
 
         elif once is not None:
             if once is not True:
                 raise ValueError('Value for once must be True. Got: %s' % once)
-            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once()
+            jTrigger = (
+                self._spark._sc  # type: ignore[attr-defined]
+                ._jvm.org.apache.spark.sql.streaming.Trigger.Once()
+            )
 
         else:
             if type(continuous) != str or len(continuous.strip()) == 0:
                 raise ValueError('Value for continuous must be a non empty string. Got: %s' %
                                  continuous)
             interval = continuous.strip()
-            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Continuous(
-                interval)
+            jTrigger = (
+                self._spark._sc  # type: ignore[attr-defined]
+                ._jvm.org.apache.spark.sql.streaming.Trigger.Continuous(interval)
+            )
 
         self._jwrite = self._jwrite.trigger(jTrigger)
         return self
 
-    def foreach(self, f):
+    @overload
+    def foreach(self, f: Callable[[Row], None]) -> "DataStreamWriter":
+        ...
+
+    @overload
+    def foreach(self, f: "SupportsProcess") -> "DataStreamWriter":
+        ...
+
+    def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataStreamWriter":
         """
         Sets the output of the streaming query to be processed using the provided writer ``f``.
         This is often used to write the output of a streaming query to arbitrary storage systems.
@@ -1024,7 +1148,7 @@ class DataStreamWriter(object):
         >>> writer = sdf.writeStream.foreach(RowPrinter())
         """
 
-        from pyspark.rdd import _wrap_function
+        from pyspark.rdd import _wrap_function  # type: ignore[attr-defined]
         from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
         from pyspark.taskcontext import TaskContext
 
@@ -1032,9 +1156,9 @@ class DataStreamWriter(object):
             # The provided object is a callable function that is supposed to be called on each row.
             # Construct a function that takes an iterator and calls the provided function on each
             # row.
-            def func_without_process(_, iterator):
+            def func_without_process(_: Any, iterator: Iterator) -> Iterator:
                 for x in iterator:
-                    f(x)
+                    f(x)  # type: ignore[operator]
                 return iter([])
 
             func = func_without_process
@@ -1050,7 +1174,7 @@ class DataStreamWriter(object):
             if not callable(getattr(f, 'process')):
                 raise TypeError("Attribute 'process' in provided object is not 
callable")
 
-            def doesMethodExist(method_name):
+            def doesMethodExist(method_name: str) -> bool:
                 exists = hasattr(f, method_name)
                 if exists and not callable(getattr(f, method_name)):
                     raise TypeError(
@@ -1060,45 +1184,51 @@ class DataStreamWriter(object):
             open_exists = doesMethodExist('open')
             close_exists = doesMethodExist('close')
 
-            def func_with_open_process_close(partition_id, iterator):
+            def func_with_open_process_close(partition_id: Any, iterator: Iterator) -> Iterator:
                 epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId')
                 if epoch_id:
-                    epoch_id = int(epoch_id)
+                    int_epoch_id = int(epoch_id)
                 else:
                     raise RuntimeError("Could not get batch id from 
TaskContext")
 
                 # Check if the data should be processed
                 should_process = True
                 if open_exists:
-                    should_process = f.open(partition_id, epoch_id)
+                    should_process = f.open(partition_id, int_epoch_id)   # type: ignore[union-attr]
 
                 error = None
 
                 try:
                     if should_process:
                         for x in iterator:
-                            f.process(x)
+                            cast("SupportsProcess", f).process(x)
                 except Exception as ex:
                     error = ex
                 finally:
                     if close_exists:
-                        f.close(error)
+                        f.close(error)  # type: ignore[union-attr]
                     if error:
                         raise error
 
                 return iter([])
 
-            func = func_with_open_process_close
+            func = func_with_open_process_close  # type: ignore[assignment]
 
         serializer = AutoBatchedSerializer(PickleSerializer())
-        wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
-        jForeachWriter = \
-            self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
+        wrapped_func = _wrap_function(
+            self._spark._sc,  # type: ignore[attr-defined]
+            func, serializer, serializer)
+        jForeachWriter = (
+            self._spark._sc    # type: ignore[attr-defined]
+            ._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
                 wrapped_func, self._df._jdf.schema())
+        )
         self._jwrite.foreach(jForeachWriter)
         return self
 
-    def foreachBatch(self, func):
+    def foreachBatch(
+        self, func: Callable[["DataFrame", int], None]
+    ) -> "DataStreamWriter":
         """
         Sets the output of the streaming query to be processed using the provided
         function. This is supported only the in the micro-batch execution modes (that is, when the
@@ -1124,7 +1254,7 @@ class DataStreamWriter(object):
         """
 
         from pyspark.java_gateway import ensure_callback_server_started
-        gw = self._spark._sc._gateway
+        gw = self._spark._sc._gateway  # type: ignore[attr-defined]
         java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*")
 
         wrapped_func = ForeachBatchFunction(self._spark, func)
@@ -1132,8 +1262,15 @@ class DataStreamWriter(object):
         ensure_callback_server_started(gw)
         return self
 
-    def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
-              **options):
+    def start(
+        self,
+        path: Optional[str] = None,
+        format: Optional[str] = None,
+        outputMode: Optional[str] = None,
+        partitionBy: Optional[Union[str, List[str]]] = None,
+        queryName: Optional[str] = None,
+        **options: "OptionalPrimitiveType",
+    ) -> StreamingQuery:
         """Streams the contents of the :class:`DataFrame` to a data source.
 
         The data source is specified by the ``format`` and a set of ``options``.
@@ -1203,8 +1340,15 @@ class DataStreamWriter(object):
         else:
             return self._sq(self._jwrite.start(path))
 
-    def toTable(self, tableName, format=None, outputMode=None, partitionBy=None, queryName=None,
-                **options):
+    def toTable(
+        self,
+        tableName: str,
+        format: Optional[str] = None,
+        outputMode: Optional[str] = None,
+        partitionBy: Optional[Union[str, List[str]]] = None,
+        queryName: Optional[str] = None,
+        **options: "OptionalPrimitiveType"
+    ) -> StreamingQuery:
         """
         Starts the execution of the streaming query, which will continually output results to the
         given table as new data arrives.
@@ -1274,20 +1418,21 @@ class DataStreamWriter(object):
         return self._sq(self._jwrite.toTable(tableName))
 
 
-def _test():
+def _test() -> None:
     import doctest
     import os
     import tempfile
     from pyspark.sql import SparkSession, SQLContext
     import pyspark.sql.streaming
+    from py4j.protocol import Py4JError
 
     os.chdir(os.environ["SPARK_HOME"])
 
     globs = pyspark.sql.streaming.__dict__.copy()
     try:
         spark = SparkSession.builder.getOrCreate()
-    except py4j.protocol.Py4JError:  # noqa: F821
-        spark = SparkSession(sc)  # noqa: F821
+    except Py4JError:  # noqa: F821
+        spark = SparkSession(sc)  # type: ignore[name-defined] # noqa: F821
 
     globs['tempfile'] = tempfile
     globs['os'] = os
diff --git a/python/pyspark/sql/streaming.pyi b/python/pyspark/sql/streaming.pyi
deleted file mode 100644
index 1d05483..0000000
--- a/python/pyspark/sql/streaming.pyi
+++ /dev/null
@@ -1,197 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import Any, Callable, Dict, List, Optional, Union
-
-from pyspark.sql._typing import SupportsProcess, OptionalPrimitiveType
-from pyspark.sql.context import SQLContext
-from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.readwriter import OptionUtils
-from pyspark.sql.types import Row, StructType
-from pyspark.sql.utils import StreamingQueryException
-
-from py4j.java_gateway import JavaObject  # type: ignore[import]
-
-class StreamingQuery:
-    def __init__(self, jsq: JavaObject) -> None: ...
-    @property
-    def id(self) -> str: ...
-    @property
-    def runId(self) -> str: ...
-    @property
-    def name(self) -> str: ...
-    @property
-    def isActive(self) -> bool: ...
-    def awaitTermination(self, timeout: Optional[int] = ...) -> Optional[bool]: ...
-    @property
-    def status(self) -> Dict[str, Any]: ...
-    @property
-    def recentProgress(self) -> List[Dict[str, Any]]: ...
-    @property
-    def lastProgress(self) -> Optional[Dict[str, Any]]: ...
-    def processAllAvailable(self) -> None: ...
-    def stop(self) -> None: ...
-    def explain(self, extended: bool = ...) -> None: ...
-    def exception(self) -> Optional[StreamingQueryException]: ...
-
-class StreamingQueryManager:
-    def __init__(self, jsqm: JavaObject) -> None: ...
-    @property
-    def active(self) -> List[StreamingQuery]: ...
-    def get(self, id: str) -> StreamingQuery: ...
-    def awaitAnyTermination(self, timeout: Optional[int] = ...) -> bool: ...
-    def resetTerminated(self) -> None: ...
-
-class DataStreamReader(OptionUtils):
-    def __init__(self, spark: SQLContext) -> None: ...
-    def format(self, source: str) -> DataStreamReader: ...
-    def schema(self, schema: Union[StructType, str]) -> DataStreamReader: ...
-    def option(self, key: str, value: OptionalPrimitiveType) -> DataStreamReader: ...
-    def options(self, **options: OptionalPrimitiveType) -> DataStreamReader: ...
-    def load(
-        self,
-        path: Optional[str] = ...,
-        format: Optional[str] = ...,
-        schema: Optional[Union[StructType, str]] = ...,
-        **options: OptionalPrimitiveType
-    ) -> DataFrame: ...
-    def json(
-        self,
-        path: str,
-        schema: Optional[Union[StructType, str]] = ...,
-        primitivesAsString: Optional[Union[bool, str]] = ...,
-        prefersDecimal: Optional[Union[bool, str]] = ...,
-        allowComments: Optional[Union[bool, str]] = ...,
-        allowUnquotedFieldNames: Optional[Union[bool, str]] = ...,
-        allowSingleQuotes: Optional[Union[bool, str]] = ...,
-        allowNumericLeadingZero: Optional[Union[bool, str]] = ...,
-        allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = ...,
-        mode: Optional[str] = ...,
-        columnNameOfCorruptRecord: Optional[str] = ...,
-        dateFormat: Optional[str] = ...,
-        timestampFormat: Optional[str] = ...,
-        multiLine: Optional[Union[bool, str]] = ...,
-        allowUnquotedControlChars: Optional[Union[bool, str]] = ...,
-        lineSep: Optional[str] = ...,
-        locale: Optional[str] = ...,
-        dropFieldIfAllNull: Optional[Union[bool, str]] = ...,
-        encoding: Optional[str] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-        allowNonNumericNumbers: Optional[Union[bool, str]] = ...,
-    ) -> DataFrame: ...
-    def orc(
-        self,
-        path: str,
-        mergeSchema: Optional[bool] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-    ) -> DataFrame: ...
-    def parquet(
-        self,
-        path: str,
-        mergeSchema: Optional[bool] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-    ) -> DataFrame: ...
-    def text(
-        self,
-        path: str,
-        wholetext: bool = ...,
-        lineSep: Optional[str] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-    ) -> DataFrame: ...
-    def csv(
-        self,
-        path: str,
-        schema: Optional[Union[StructType, str]] = ...,
-        sep: Optional[str] = ...,
-        encoding: Optional[str] = ...,
-        quote: Optional[str] = ...,
-        escape: Optional[str] = ...,
-        comment: Optional[str] = ...,
-        header: Optional[Union[bool, str]] = ...,
-        inferSchema: Optional[Union[bool, str]] = ...,
-        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = ...,
-        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = ...,
-        nullValue: Optional[str] = ...,
-        nanValue: Optional[str] = ...,
-        positiveInf: Optional[str] = ...,
-        negativeInf: Optional[str] = ...,
-        dateFormat: Optional[str] = ...,
-        timestampFormat: Optional[str] = ...,
-        maxColumns: Optional[Union[int, str]] = ...,
-        maxCharsPerColumn: Optional[Union[int, str]] = ...,
-        mode: Optional[str] = ...,
-        columnNameOfCorruptRecord: Optional[str] = ...,
-        multiLine: Optional[Union[bool, str]] = ...,
-        charToEscapeQuoteEscaping: Optional[Union[bool, str]] = ...,
-        enforceSchema: Optional[Union[bool, str]] = ...,
-        emptyValue: Optional[str] = ...,
-        locale: Optional[str] = ...,
-        lineSep: Optional[str] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-        unescapedQuoteHandling: Optional[str] = ...,
-    ) -> DataFrame: ...
-    def table(self, tableName: str) -> DataFrame: ...
-
-class DataStreamWriter:
-    def __init__(self, df: DataFrame) -> None: ...
-    def outputMode(self, outputMode: str) -> DataStreamWriter: ...
-    def format(self, source: str) -> DataStreamWriter: ...
-    def option(self, key: str, value: OptionalPrimitiveType) -> DataStreamWriter: ...
-    def options(self, **options: OptionalPrimitiveType) -> DataStreamWriter: ...
-    @overload
-    def partitionBy(self, *cols: str) -> DataStreamWriter: ...
-    @overload
-    def partitionBy(self, __cols: List[str]) -> DataStreamWriter: ...
-    def queryName(self, queryName: str) -> DataStreamWriter: ...
-    @overload
-    def trigger(self, processingTime: str) -> DataStreamWriter: ...
-    @overload
-    def trigger(self, once: bool) -> DataStreamWriter: ...
-    @overload
-    def trigger(self, continuous: bool) -> DataStreamWriter: ...
-    def start(
-        self,
-        path: Optional[str] = ...,
-        format: Optional[str] = ...,
-        outputMode: Optional[str] = ...,
-        partitionBy: Optional[Union[str, List[str]]] = ...,
-        queryName: Optional[str] = ...,
-        **options: OptionalPrimitiveType
-    ) -> StreamingQuery: ...
-    @overload
-    def foreach(self, f: Callable[[Row], None]) -> DataStreamWriter: ...
-    @overload
-    def foreach(self, f: SupportsProcess) -> DataStreamWriter: ...
-    def foreachBatch(
-        self, func: Callable[[DataFrame, int], None]
-    ) -> DataStreamWriter: ...
-    def toTable(
-        self,
-        tableName: str,
-        format: Optional[str] = ...,
-        outputMode: Optional[str] = ...,
-        partitionBy: Optional[Union[str, List[str]]] = ...,
-        queryName: Optional[str] = ...,
-        **options: OptionalPrimitiveType
-    ) -> StreamingQuery: ...

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
