This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e3b4d1e  [SPARK-36991][PYTHON] Inline type hints for spark/python/pyspark/sql/streaming.py
e3b4d1e is described below

commit e3b4d1eef37704f8ebb5bc669ad041166cc674f0
Author: Xinrong Meng <xinrong.m...@databricks.com>
AuthorDate: Fri Oct 15 12:05:08 2021 -0700

    [SPARK-36991][PYTHON] Inline type hints for spark/python/pyspark/sql/streaming.py

    ### What changes were proposed in this pull request?
    Inline type hints for spark/python/pyspark/sql/streaming.py.

    ### Why are the changes needed?
    The current stub files cannot support type checking within function bodies, so we want to
    inline the type hints for spark/python/pyspark/sql/streaming.py.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Existing tests.

    Closes #34277 from xinrong-databricks/inline_streaming.

    Authored-by: Xinrong Meng <xinrong.m...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/sql/streaming.py  | 355 +++++++++++++++++++++++++++------------
 python/pyspark/sql/streaming.pyi | 197 ----------------------
 2 files changed, 250 insertions(+), 302 deletions(-)

diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 08c8934..24cd2db 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -17,15 +17,22 @@
 import sys
 import json
+from collections.abc import Iterator
+from typing import cast, overload, Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
 
-from py4j.java_gateway import java_import
+from py4j.java_gateway import java_import, JavaObject
 
-from pyspark import since, keyword_only
+from pyspark import since
 from pyspark.sql.column import _to_seq
-from pyspark.sql.readwriter import OptionUtils, to_str
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.readwriter import OptionUtils, to_str  # type: ignore[attr-defined]
+from pyspark.sql.types import Row, StructType, StructField, StringType
 from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException
 
+if TYPE_CHECKING:
+    from pyspark.sql import SQLContext
+    from pyspark.sql._typing import SupportsProcess, OptionalPrimitiveType
+    from pyspark.sql.dataframe import DataFrame
+
 __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"]
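Note on the motivation stated above: a .pyi stub types only a module's public surface, so
mypy never looks inside the function bodies of the corresponding .py file; once the hints
are inlined, the bodies are checked too. A minimal, self-contained sketch of the difference
(the function below is illustrative only, it is not from this patch):

    from typing import Optional

    # With the hint in a separate .pyi stub, mypy sees only the signature and
    # skips this body. With the hint inlined, mypy checks the body as well and
    # flags the unguarded attribute access:
    def first_word(s: Optional[str]) -> str:
        return s.split()[0]  # error: Item "None" of "Optional[str]" has no attribute "split"

The `if TYPE_CHECKING:` block added above serves a related purpose: it keeps
annotation-only imports such as DataFrame out of the runtime import graph, which avoids
circular imports while still giving the checker the names it needs.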
@@ -41,12 +48,12 @@ class StreamingQuery(object):
     This API is evolving.
     """
 
-    def __init__(self, jsq):
+    def __init__(self, jsq: JavaObject) -> None:
         self._jsq = jsq
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.0)
-    def id(self):
+    def id(self) -> str:
         """Returns the unique id of this query that persists across restarts from checkpoint data.
         That is, this id is generated when a query is started for the first time, and
         will be the same every time it is restarted from checkpoint data.
@@ -55,17 +62,17 @@ class StreamingQuery(object):
         """
         return self._jsq.id().toString()
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.1)
-    def runId(self):
+    def runId(self) -> str:
         """Returns the unique id of this query that does not persist across restarts. That is,
         every query that is started (or restarted from checkpoint) will have a different runId.
         """
         return self._jsq.runId().toString()
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.0)
-    def name(self):
+    def name(self) -> str:
         """Returns the user-specified name of the query, or null if not specified.
         This name can be specified in the `org.apache.spark.sql.streaming.DataStreamWriter`
         as `dataframe.writeStream.queryName("query").start()`.
@@ -73,15 +80,15 @@ class StreamingQuery(object):
         """
         return self._jsq.name()
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.0)
-    def isActive(self):
+    def isActive(self) -> bool:
         """Whether this streaming query is currently active or not.
         """
         return self._jsq.isActive()
 
     @since(2.0)
-    def awaitTermination(self, timeout=None):
+    def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
         """Waits for the termination of `this` query, either by :func:`query.stop()` or by an
         exception. If the query has terminated with an exception, then the exception will be
         thrown. If `timeout` is set, it returns whether the query has terminated or not within the
@@ -100,17 +107,17 @@ class StreamingQuery(object):
         else:
             return self._jsq.awaitTermination()
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.1)
-    def status(self):
+    def status(self) -> Dict[str, Any]:
         """
         Returns the current status of the query.
         """
         return json.loads(self._jsq.status().json())
 
-    @property
+    @property  # type: ignore[misc]
     @since(2.1)
-    def recentProgress(self):
+    def recentProgress(self) -> List[Dict[str, Any]]:
         """Returns an array of the most recent [[StreamingQueryProgress]] updates for this query.
         The number of progress updates retained for each stream is configured by Spark session
         configuration `spark.sql.streaming.numRecentProgressUpdates`.
@@ -118,7 +125,7 @@ class StreamingQuery(object):
         return [json.loads(p.json()) for p in self._jsq.recentProgress()]
 
     @property
-    def lastProgress(self):
+    def lastProgress(self) -> Optional[Dict[str, Any]]:
         """
         Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or
         None if there were no progress updates
@@ -135,7 +142,7 @@ class StreamingQuery(object):
         else:
             return None
 
-    def processAllAvailable(self):
+    def processAllAvailable(self) -> None:
         """Blocks until all available data in the source has been processed and committed to the
         sink. This method is intended for testing.
@@ -151,12 +158,12 @@ class StreamingQuery(object):
         return self._jsq.processAllAvailable()
 
     @since(2.0)
-    def stop(self):
+    def stop(self) -> None:
         """Stop this streaming query.
         """
         self._jsq.stop()
 
-    def explain(self, extended=False):
+    def explain(self, extended: bool = False) -> None:
         """Prints the (logical and physical) plans to the console for debugging purpose.
 
         .. versionadded:: 2.1.0
@@ -188,7 +195,7 @@ class StreamingQuery(object):
         # We should print it in the Python process.
         print(self._jsq.explainInternal(extended))
 
-    def exception(self):
+    def exception(self) -> Optional[StreamingQueryException]:
         """
         .. versionadded:: 2.1.0
@@ -216,11 +223,11 @@ class StreamingQueryManager(object):
     This API is evolving.
     """
 
-    def __init__(self, jsqm):
+    def __init__(self, jsqm: JavaObject) -> None:
         self._jsqm = jsqm
 
     @property
-    def active(self):
+    def active(self) -> List[StreamingQuery]:
         """Returns a list of active queries associated with this SQLContext
 
         .. versionadded:: 2.0.0
@@ -236,7 +243,7 @@ class StreamingQueryManager(object):
         """
         return [StreamingQuery(jsq) for jsq in self._jsqm.active()]
 
-    def get(self, id):
+    def get(self, id: str) -> StreamingQuery:
         """Returns an active query from this SQLContext or throws exception if an active
         query with this name doesn't exist.
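The Optional[bool] return type on awaitTermination above (and on awaitAnyTermination in the
next hunk) mirrors the runtime contract: with a timeout the call reports whether the query
finished in time, and without one it blocks until termination and returns None. A hedged
usage sketch, assuming a running Spark session bound to `spark`:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed to be available
    sdf = spark.readStream.format("rate").load()
    query = sdf.writeStream.format("console").start()

    # timeout given -> bool; no timeout -> blocks and returns None,
    # hence the Optional[bool] annotation.
    finished = query.awaitTermination(timeout=5)
    if not finished:
        query.stop()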
@@ -258,7 +265,7 @@ class StreamingQueryManager(object):
         return StreamingQuery(self._jsqm.get(id))
 
     @since(2.0)
-    def awaitAnyTermination(self, timeout=None):
+    def awaitAnyTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
         """Wait until any of the queries on the associated SQLContext has terminated since the
         creation of the context, or since :func:`resetTerminated()` was called. If any query was
         terminated with an exception, then the exception will be thrown.
@@ -285,7 +292,7 @@ class StreamingQueryManager(object):
         else:
             return self._jsqm.awaitAnyTermination()
 
-    def resetTerminated(self):
+    def resetTerminated(self) -> None:
         """Forget about past terminated queries so that :func:`awaitAnyTermination()` can be used
         again to wait for new terminations.
 
@@ -311,15 +318,15 @@ class DataStreamReader(OptionUtils):
     This API is evolving.
     """
 
-    def __init__(self, spark):
-        self._jreader = spark._ssql_ctx.readStream()
+    def __init__(self, spark: "SQLContext") -> None:
+        self._jreader = spark._ssql_ctx.readStream()  # type: ignore[attr-defined]
         self._spark = spark
 
-    def _df(self, jdf):
+    def _df(self, jdf: JavaObject) -> "DataFrame":
         from pyspark.sql.dataframe import DataFrame
         return DataFrame(jdf, self._spark)
 
-    def format(self, source):
+    def format(self, source: str) -> "DataStreamReader":
         """Specifies the input data source format.
 
         .. versionadded:: 2.0.0
@@ -340,7 +347,7 @@ class DataStreamReader(OptionUtils):
         self._jreader = self._jreader.format(source)
         return self
 
-    def schema(self, schema):
+    def schema(self, schema: Union[StructType, str]) -> "DataStreamReader":
         """Specifies the input schema.
 
         Some data sources (e.g. JSON) can infer the input schema automatically from data.
@@ -375,7 +382,7 @@ class DataStreamReader(OptionUtils):
             raise TypeError("schema should be StructType or string")
         return self
 
-    def option(self, key, value):
+    def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamReader":
         """Adds an input option for the underlying data source.
 
         .. versionadded:: 2.0.0
@@ -391,7 +398,7 @@ class DataStreamReader(OptionUtils):
         self._jreader = self._jreader.option(key, to_str(value))
         return self
 
-    def options(self, **options):
+    def options(self, **options: "OptionalPrimitiveType") -> "DataStreamReader":
         """Adds input options for the underlying data source.
 
         .. versionadded:: 2.0.0
@@ -408,7 +415,13 @@
             self._jreader = self._jreader.option(k, to_str(options[k]))
         return self
 
-    def load(self, path=None, format=None, schema=None, **options):
+    def load(
+        self,
+        path: Optional[str] = None,
+        format: Optional[str] = None,
+        schema: Optional[Union[StructType, str]] = None,
+        **options: "OptionalPrimitiveType"
+    ) -> "DataFrame":
         """Loads a data stream from a data source and returns it as a
         :class:`DataFrame <pyspark.sql.DataFrame>`.
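Because format/schema/option/options are now annotated to return "DataStreamReader", a type
checker can follow the whole builder chain down to the DataFrame produced by load. A small
usage sketch (the socket host and port are placeholders, and a Spark runtime is assumed):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # Each builder call is annotated to return "DataStreamReader", so the
    # chained result of load() is known statically to be a DataFrame.
    sdf = (
        spark.readStream
        .format("socket")
        .option("host", "localhost")   # placeholder host
        .option("port", 9999)          # OptionalPrimitiveType accepts int, not only str
        .load()
    )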
@@ -453,13 +466,31 @@ class DataStreamReader(OptionUtils):
         else:
             return self._df(self._jreader.load())
 
-    def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
-             allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
-             allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
-             mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
-             multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None,
-             dropFieldIfAllNull=None, encoding=None, pathGlobFilter=None,
-             recursiveFileLookup=None, allowNonNumericNumbers=None):
+    def json(
+        self,
+        path: str,
+        schema: Optional[Union[StructType, str]] = None,
+        primitivesAsString: Optional[Union[bool, str]] = None,
+        prefersDecimal: Optional[Union[bool, str]] = None,
+        allowComments: Optional[Union[bool, str]] = None,
+        allowUnquotedFieldNames: Optional[Union[bool, str]] = None,
+        allowSingleQuotes: Optional[Union[bool, str]] = None,
+        allowNumericLeadingZero: Optional[Union[bool, str]] = None,
+        allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = None,
+        mode: Optional[str] = None,
+        columnNameOfCorruptRecord: Optional[str] = None,
+        dateFormat: Optional[str] = None,
+        timestampFormat: Optional[str] = None,
+        multiLine: Optional[Union[bool, str]] = None,
+        allowUnquotedControlChars: Optional[Union[bool, str]] = None,
+        lineSep: Optional[str] = None,
+        locale: Optional[str] = None,
+        dropFieldIfAllNull: Optional[Union[bool, str]] = None,
+        encoding: Optional[str] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+        allowNonNumericNumbers: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
         """
         Loads a JSON file stream and returns the results as a :class:`DataFrame`.
@@ -501,7 +532,7 @@ class DataStreamReader(OptionUtils):
         >>> json_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(
+        self._set_opts(  # type: ignore[attr-defined]
             schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
             allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -517,7 +548,13 @@
         else:
             raise TypeError("path can be only a single string")
 
-    def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None):
+    def orc(
+        self,
+        path: str,
+        mergeSchema: Optional[bool] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
         """Loads a ORC file stream, returning the result as a :class:`DataFrame`.
 
         .. versionadded:: 2.3.0
@@ -539,15 +576,25 @@
         >>> orc_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter,
-                       recursiveFileLookup=recursiveFileLookup)
+        self._set_opts(  # type: ignore[attr-defined]
+            mergeSchema=mergeSchema,
+            pathGlobFilter=pathGlobFilter,
+            recursiveFileLookup=recursiveFileLookup
+        )
         if isinstance(path, str):
             return self._df(self._jreader.orc(path))
         else:
             raise TypeError("path can be only a single string")
 
-    def parquet(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=None,
-                datetimeRebaseMode=None, int96RebaseMode=None):
+    def parquet(
+        self,
+        path: str,
+        mergeSchema: Optional[bool] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+        datetimeRebaseMode: Optional[Union[bool, str]] = None,
+        int96RebaseMode: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
         """
         Loads a Parquet file stream, returning the result as a :class:`DataFrame`.
@@ -575,16 +622,26 @@
         >>> parquet_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter,
-                       recursiveFileLookup=recursiveFileLookup,
-                       datetimeRebaseMode=datetimeRebaseMode, int96RebaseMode=int96RebaseMode)
+        self._set_opts(  # type: ignore[attr-defined]
+            mergeSchema=mergeSchema,
+            pathGlobFilter=pathGlobFilter,
+            recursiveFileLookup=recursiveFileLookup,
+            datetimeRebaseMode=datetimeRebaseMode,
+            int96RebaseMode=int96RebaseMode
+        )
         if isinstance(path, str):
             return self._df(self._jreader.parquet(path))
        else:
             raise TypeError("path can be only a single string")
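The schema parameters throughout this reader are typed Union[StructType, str], meaning a
programmatic StructType or a DDL-formatted string are both accepted. A hedged sketch of the
two equivalent spellings (the input path is a placeholder, and file sources generally
require an explicit schema when streaming):

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType

    spark = SparkSession.builder.getOrCreate()  # assumed runtime

    # StructType form:
    schema = StructType([StructField("value", StringType(), True)])
    sdf1 = spark.readStream.schema(schema).json("/tmp/stream-input")

    # DDL-string form, accepted by the same Union[StructType, str] parameter:
    sdf2 = spark.readStream.schema("value STRING").json("/tmp/stream-input")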
 
-    def text(self, path, wholetext=False, lineSep=None, pathGlobFilter=None,
-             recursiveFileLookup=None):
+    def text(
+        self,
+        path: str,
+        wholetext: bool = False,
+        lineSep: Optional[str] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
         """
         Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a
         string column named "value", and followed by partitioned columns if there
@@ -597,7 +654,7 @@
 
         Parameters
         ----------
-        paths : str or list
+        path : str or list
             string, or list of strings, for input path(s).
 
         Other Parameters
@@ -621,7 +678,7 @@
         >>> "value" in str(text_sdf.schema)
         True
         """
-        self._set_opts(
+        self._set_opts(  # type: ignore[attr-defined]
             wholetext=wholetext, lineSep=lineSep, pathGlobFilter=pathGlobFilter,
             recursiveFileLookup=recursiveFileLookup)
         if isinstance(path, str):
@@ -629,14 +686,40 @@
         else:
             raise TypeError("path can be only a single string")
 
-    def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
-            comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
-            ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
-            negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
-            maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
-            columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-            enforceSchema=None, emptyValue=None, locale=None, lineSep=None,
-            pathGlobFilter=None, recursiveFileLookup=None, unescapedQuoteHandling=None):
+    def csv(
+        self,
+        path: str,
+        schema: Optional[Union[StructType, str]] = None,
+        sep: Optional[str] = None,
+        encoding: Optional[str] = None,
+        quote: Optional[str] = None,
+        escape: Optional[str] = None,
+        comment: Optional[str] = None,
+        header: Optional[Union[bool, str]] = None,
+        inferSchema: Optional[Union[bool, str]] = None,
+        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
+        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
+        nullValue: Optional[str] = None,
+        nanValue: Optional[str] = None,
+        positiveInf: Optional[str] = None,
+        negativeInf: Optional[str] = None,
+        dateFormat: Optional[str] = None,
+        timestampFormat: Optional[str] = None,
+        maxColumns: Optional[Union[int, str]] = None,
+        maxCharsPerColumn: Optional[Union[int, str]] = None,
+        maxMalformedLogPerPartition: Optional[Union[int, str]] = None,
+        mode: Optional[str] = None,
+        columnNameOfCorruptRecord: Optional[str] = None,
+        multiLine: Optional[Union[bool, str]] = None,
+        charToEscapeQuoteEscaping: Optional[Union[bool, str]] = None,
+        enforceSchema: Optional[Union[bool, str]] = None,
+        emptyValue: Optional[str] = None,
+        locale: Optional[str] = None,
+        lineSep: Optional[str] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+        unescapedQuoteHandling: Optional[str] = None,
+    ) -> "DataFrame":
         r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -674,7 +757,7 @@
         >>> csv_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(
+        self._set_opts(  # type: ignore[attr-defined]
             schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
@@ -692,7 +775,7 @@
         else:
             raise TypeError("path can be only a single string")
 
-    def table(self, tableName):
+    def table(self, tableName: str) -> "DataFrame":
         """Define a Streaming DataFrame on a Table. The DataSource corresponding to the table should
         support streaming mode.
@@ -735,16 +818,16 @@ class DataStreamWriter(object):
     This API is evolving.
""" - def __init__(self, df): + def __init__(self, df: "DataFrame") -> None: self._df = df self._spark = df.sql_ctx self._jwrite = df._jdf.writeStream() - def _sq(self, jsq): + def _sq(self, jsq: JavaObject) -> StreamingQuery: from pyspark.sql.streaming import StreamingQuery return StreamingQuery(jsq) - def outputMode(self, outputMode): + def outputMode(self, outputMode: str) -> "DataStreamWriter": """Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. .. versionadded:: 2.0.0 @@ -772,7 +855,7 @@ class DataStreamWriter(object): self._jwrite = self._jwrite.outputMode(outputMode) return self - def format(self, source): + def format(self, source: str) -> "DataStreamWriter": """Specifies the underlying output data source. .. versionadded:: 2.0.0 @@ -793,7 +876,7 @@ class DataStreamWriter(object): self._jwrite = self._jwrite.format(source) return self - def option(self, key, value): + def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamWriter": """Adds an output option for the underlying data source. .. versionadded:: 2.0.0 @@ -805,7 +888,7 @@ class DataStreamWriter(object): self._jwrite = self._jwrite.option(key, to_str(value)) return self - def options(self, **options): + def options(self, **options: "OptionalPrimitiveType") -> "DataStreamWriter": """Adds output options for the underlying data source. .. versionadded:: 2.0.0 @@ -818,7 +901,15 @@ class DataStreamWriter(object): self._jwrite = self._jwrite.option(k, to_str(options[k])) return self - def partitionBy(self, *cols): + @overload + def partitionBy(self, *cols: str) -> "DataStreamWriter": + ... + + @overload + def partitionBy(self, __cols: List[str]) -> "DataStreamWriter": + ... + + def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc] """Partitions the output by the given columns on the file system. If specified, the output is laid out on the file system similar @@ -837,10 +928,11 @@ class DataStreamWriter(object): """ if len(cols) == 1 and isinstance(cols[0], (list, tuple)): cols = cols[0] - self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols)) + self._jwrite = self._jwrite.partitionBy( + _to_seq(self._spark._sc, cols)) # type: ignore[attr-defined] return self - def queryName(self, queryName): + def queryName(self, queryName: str) -> "DataStreamWriter": """Specifies the name of the :class:`StreamingQuery` that can be started with :func:`start`. This name must be unique among all the currently active queries in the associated SparkSession. @@ -865,8 +957,25 @@ class DataStreamWriter(object): self._jwrite = self._jwrite.queryName(queryName) return self - @keyword_only - def trigger(self, *, processingTime=None, once=None, continuous=None): + @overload + def trigger(self, *, processingTime: str) -> "DataStreamWriter": + ... + + @overload + def trigger(self, *, once: bool) -> "DataStreamWriter": + ... + + @overload + def trigger(self, *, continuous: str) -> "DataStreamWriter": + ... + + def trigger( + self, + *, + processingTime: Optional[str] = None, + once: Optional[bool] = None, + continuous: Optional[str] = None, + ) -> "DataStreamWriter": """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. @@ -912,26 +1021,41 @@ class DataStreamWriter(object): raise ValueError('Value for processingTime must be a non empty string. 
@@ -912,26 +1021,41 @@
                 raise ValueError('Value for processingTime must be a non empty string. Got: %s' %
                                  processingTime)
             interval = processingTime.strip()
-            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime(
-                interval)
+            jTrigger = (
+                self._spark._sc  # type: ignore[attr-defined]
+                ._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime(interval)
+            )
 
         elif once is not None:
             if once is not True:
                 raise ValueError('Value for once must be True. Got: %s' % once)
-            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once()
+            jTrigger = (
+                self._spark._sc  # type: ignore[attr-defined]
+                ._jvm.org.apache.spark.sql.streaming.Trigger.Once()
+            )
         else:
             if type(continuous) != str or len(continuous.strip()) == 0:
                 raise ValueError('Value for continuous must be a non empty string. Got: %s' %
                                  continuous)
             interval = continuous.strip()
-            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Continuous(
-                interval)
+            jTrigger = (
+                self._spark._sc  # type: ignore[attr-defined]
+                ._jvm.org.apache.spark.sql.streaming.Trigger.Continuous(interval)
+            )
 
         self._jwrite = self._jwrite.trigger(jTrigger)
         return self
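With @keyword_only replaced by a real keyword-only signature plus three overloads, the
checker now sees the trigger kinds as mutually exclusive. A hedged usage sketch, assuming a
Spark runtime:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed runtime
    sdf = spark.readStream.format("rate").load()

    # Exactly one keyword per call; the overloads encode the three
    # mutually exclusive trigger kinds, and runtime validation agrees.
    query = (
        sdf.writeStream
        .format("console")
        .trigger(processingTime="5 seconds")   # or: trigger(once=True)
        .start()                               # or: trigger(continuous="1 second")
    )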
-    def foreach(self, f):
+    @overload
+    def foreach(self, f: Callable[[Row], None]) -> "DataStreamWriter":
+        ...
+
+    @overload
+    def foreach(self, f: "SupportsProcess") -> "DataStreamWriter":
+        ...
+
+    def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataStreamWriter":
         """
         Sets the output of the streaming query to be processed using the provided
         writer ``f``. This is often used to write the output of a streaming query to arbitrary
         storage systems.
@@ -1024,7 +1148,7 @@
 
         >>> writer = sdf.writeStream.foreach(RowPrinter())
         """
-        from pyspark.rdd import _wrap_function
+        from pyspark.rdd import _wrap_function  # type: ignore[attr-defined]
         from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
         from pyspark.taskcontext import TaskContext
 
@@ -1032,9 +1156,9 @@
             # The provided object is a callable function that is supposed to be called on each row.
             # Construct a function that takes an iterator and calls the provided function on each
             # row.
-            def func_without_process(_, iterator):
+            def func_without_process(_: Any, iterator: Iterator) -> Iterator:
                 for x in iterator:
-                    f(x)
+                    f(x)  # type: ignore[operator]
                 return iter([])
 
             func = func_without_process
@@ -1050,7 +1174,7 @@
             if not callable(getattr(f, 'process')):
                 raise TypeError("Attribute 'process' in provided object is not callable")
 
-            def doesMethodExist(method_name):
+            def doesMethodExist(method_name: str) -> bool:
                 exists = hasattr(f, method_name)
                 if exists and not callable(getattr(f, method_name)):
                     raise TypeError(
@@ -1060,45 +1184,51 @@
             open_exists = doesMethodExist('open')
             close_exists = doesMethodExist('close')
 
-            def func_with_open_process_close(partition_id, iterator):
+            def func_with_open_process_close(partition_id: Any, iterator: Iterator) -> Iterator:
                 epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId')
                 if epoch_id:
-                    epoch_id = int(epoch_id)
+                    int_epoch_id = int(epoch_id)
                 else:
                     raise RuntimeError("Could not get batch id from TaskContext")
 
                 # Check if the data should be processed
                 should_process = True
                 if open_exists:
-                    should_process = f.open(partition_id, epoch_id)
+                    should_process = f.open(partition_id, int_epoch_id)  # type: ignore[union-attr]
 
                 error = None
 
                 try:
                     if should_process:
                         for x in iterator:
-                            f.process(x)
+                            cast("SupportsProcess", f).process(x)
                 except Exception as ex:
                     error = ex
                 finally:
                     if close_exists:
-                        f.close(error)
+                        f.close(error)  # type: ignore[union-attr]
 
                 if error:
                     raise error
 
                 return iter([])
 
-            func = func_with_open_process_close
+            func = func_with_open_process_close  # type: ignore[assignment]
 
         serializer = AutoBatchedSerializer(PickleSerializer())
-        wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
-        jForeachWriter = \
-            self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
+        wrapped_func = _wrap_function(
+            self._spark._sc,  # type: ignore[attr-defined]
+            func, serializer, serializer)
+        jForeachWriter = (
+            self._spark._sc  # type: ignore[attr-defined]
+            ._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
                 wrapped_func, self._df._jdf.schema())
+        )
         self._jwrite.foreach(jForeachWriter)
         return self
 
-    def foreachBatch(self, func):
+    def foreachBatch(
+        self, func: Callable[["DataFrame", int], None]
+    ) -> "DataStreamWriter":
         """
         Sets the output of the streaming query to be processed using the provided
         function. This is supported only the in the micro-batch execution modes (that is, when the
@@ -1124,7 +1254,7 @@
         """
         from pyspark.java_gateway import ensure_callback_server_started
 
-        gw = self._spark._sc._gateway
+        gw = self._spark._sc._gateway  # type: ignore[attr-defined]
         java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*")
 
         wrapped_func = ForeachBatchFunction(self._spark, func)
@@ -1132,8 +1262,15 @@
         ensure_callback_server_started(gw)
         return self
 
-    def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
-              **options):
+    def start(
+        self,
+        path: Optional[str] = None,
+        format: Optional[str] = None,
+        outputMode: Optional[str] = None,
+        partitionBy: Optional[Union[str, List[str]]] = None,
+        queryName: Optional[str] = None,
+        **options: "OptionalPrimitiveType",
+    ) -> StreamingQuery:
         """Streams the contents of the :class:`DataFrame` to a data source.
 
         The data source is specified by the ``format`` and a set of ``options``.
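Stepping back to the foreach change above: the two overloads accept either a plain callable
over Row or an object matching the SupportsProcess protocol, whose optional open/close
methods are detected via hasattr at runtime, as the inlined implementation shows. A sketch
of the object form, modeled on the RowPrinter from the method's own docstring (a Spark
runtime is assumed):

    from typing import Optional
    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed runtime
    sdf = spark.readStream.format("rate").load()

    class RowPrinter:
        # Shape matched by the "SupportsProcess" overload: process is
        # required; open and close are optional.
        def open(self, partition_id: int, epoch_id: int) -> bool:
            return True  # True means: process this partition's rows

        def process(self, row: Row) -> None:
            print(row)

        def close(self, error: Optional[Exception]) -> None:
            if error is not None:
                raise error

    query = sdf.writeStream.foreach(RowPrinter()).start()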
@@ -1203,8 +1340,15 @@
         else:
             return self._sq(self._jwrite.start(path))
 
-    def toTable(self, tableName, format=None, outputMode=None, partitionBy=None, queryName=None,
-                **options):
+    def toTable(
+        self,
+        tableName: str,
+        format: Optional[str] = None,
+        outputMode: Optional[str] = None,
+        partitionBy: Optional[Union[str, List[str]]] = None,
+        queryName: Optional[str] = None,
+        **options: "OptionalPrimitiveType"
+    ) -> StreamingQuery:
         """
         Starts the execution of the streaming query, which will continually output results to the
         given table as new data arrives.
@@ -1274,20 +1418,21 @@
         return self._sq(self._jwrite.toTable(tableName))
 
 
-def _test():
+def _test() -> None:
     import doctest
     import os
     import tempfile
     from pyspark.sql import SparkSession, SQLContext
     import pyspark.sql.streaming
+    from py4j.protocol import Py4JError
 
     os.chdir(os.environ["SPARK_HOME"])
 
     globs = pyspark.sql.streaming.__dict__.copy()
     try:
         spark = SparkSession.builder.getOrCreate()
-    except py4j.protocol.Py4JError:  # noqa: F821
-        spark = SparkSession(sc)  # noqa: F821
+    except Py4JError:  # noqa: F821
+        spark = SparkSession(sc)  # type: ignore[name-defined]  # noqa: F821
 
     globs['tempfile'] = tempfile
     globs['os'] = os
diff --git a/python/pyspark/sql/streaming.pyi b/python/pyspark/sql/streaming.pyi
deleted file mode 100644
index 1d05483..0000000
--- a/python/pyspark/sql/streaming.pyi
+++ /dev/null
@@ -1,197 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import Any, Callable, Dict, List, Optional, Union
-
-from pyspark.sql._typing import SupportsProcess, OptionalPrimitiveType
-from pyspark.sql.context import SQLContext
-from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.readwriter import OptionUtils
-from pyspark.sql.types import Row, StructType
-from pyspark.sql.utils import StreamingQueryException
-
-from py4j.java_gateway import JavaObject  # type: ignore[import]
-
-class StreamingQuery:
-    def __init__(self, jsq: JavaObject) -> None: ...
-    @property
-    def id(self) -> str: ...
-    @property
-    def runId(self) -> str: ...
-    @property
-    def name(self) -> str: ...
-    @property
-    def isActive(self) -> bool: ...
-    def awaitTermination(self, timeout: Optional[int] = ...) -> Optional[bool]: ...
-    @property
-    def status(self) -> Dict[str, Any]: ...
-    @property
-    def recentProgress(self) -> List[Dict[str, Any]]: ...
-    @property
-    def lastProgress(self) -> Optional[Dict[str, Any]]: ...
-    def processAllAvailable(self) -> None: ...
-    def stop(self) -> None: ...
-    def explain(self, extended: bool = ...) -> None: ...
-    def exception(self) -> Optional[StreamingQueryException]: ...
-
-class StreamingQueryManager:
-    def __init__(self, jsqm: JavaObject) -> None: ...
-    @property
-    def active(self) -> List[StreamingQuery]: ...
-    def get(self, id: str) -> StreamingQuery: ...
-    def awaitAnyTermination(self, timeout: Optional[int] = ...) -> bool: ...
-    def resetTerminated(self) -> None: ...
-
-class DataStreamReader(OptionUtils):
-    def __init__(self, spark: SQLContext) -> None: ...
-    def format(self, source: str) -> DataStreamReader: ...
-    def schema(self, schema: Union[StructType, str]) -> DataStreamReader: ...
-    def option(self, key: str, value: OptionalPrimitiveType) -> DataStreamReader: ...
-    def options(self, **options: OptionalPrimitiveType) -> DataStreamReader: ...
-    def load(
-        self,
-        path: Optional[str] = ...,
-        format: Optional[str] = ...,
-        schema: Optional[Union[StructType, str]] = ...,
-        **options: OptionalPrimitiveType
-    ) -> DataFrame: ...
-    def json(
-        self,
-        path: str,
-        schema: Optional[Union[StructType, str]] = ...,
-        primitivesAsString: Optional[Union[bool, str]] = ...,
-        prefersDecimal: Optional[Union[bool, str]] = ...,
-        allowComments: Optional[Union[bool, str]] = ...,
-        allowUnquotedFieldNames: Optional[Union[bool, str]] = ...,
-        allowSingleQuotes: Optional[Union[bool, str]] = ...,
-        allowNumericLeadingZero: Optional[Union[bool, str]] = ...,
-        allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = ...,
-        mode: Optional[str] = ...,
-        columnNameOfCorruptRecord: Optional[str] = ...,
-        dateFormat: Optional[str] = ...,
-        timestampFormat: Optional[str] = ...,
-        multiLine: Optional[Union[bool, str]] = ...,
-        allowUnquotedControlChars: Optional[Union[bool, str]] = ...,
-        lineSep: Optional[str] = ...,
-        locale: Optional[str] = ...,
-        dropFieldIfAllNull: Optional[Union[bool, str]] = ...,
-        encoding: Optional[str] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-        allowNonNumericNumbers: Optional[Union[bool, str]] = ...,
-    ) -> DataFrame: ...
-    def orc(
-        self,
-        path: str,
-        mergeSchema: Optional[bool] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-    ) -> DataFrame: ...
-    def parquet(
-        self,
-        path: str,
-        mergeSchema: Optional[bool] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-    ) -> DataFrame: ...
-    def text(
-        self,
-        path: str,
-        wholetext: bool = ...,
-        lineSep: Optional[str] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-    ) -> DataFrame: ...
-    def csv(
-        self,
-        path: str,
-        schema: Optional[Union[StructType, str]] = ...,
-        sep: Optional[str] = ...,
-        encoding: Optional[str] = ...,
-        quote: Optional[str] = ...,
-        escape: Optional[str] = ...,
-        comment: Optional[str] = ...,
-        header: Optional[Union[bool, str]] = ...,
-        inferSchema: Optional[Union[bool, str]] = ...,
-        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = ...,
-        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = ...,
-        nullValue: Optional[str] = ...,
-        nanValue: Optional[str] = ...,
-        positiveInf: Optional[str] = ...,
-        negativeInf: Optional[str] = ...,
-        dateFormat: Optional[str] = ...,
-        timestampFormat: Optional[str] = ...,
-        maxColumns: Optional[Union[int, str]] = ...,
-        maxCharsPerColumn: Optional[Union[int, str]] = ...,
-        mode: Optional[str] = ...,
-        columnNameOfCorruptRecord: Optional[str] = ...,
-        multiLine: Optional[Union[bool, str]] = ...,
-        charToEscapeQuoteEscaping: Optional[Union[bool, str]] = ...,
-        enforceSchema: Optional[Union[bool, str]] = ...,
-        emptyValue: Optional[str] = ...,
-        locale: Optional[str] = ...,
-        lineSep: Optional[str] = ...,
-        pathGlobFilter: Optional[Union[bool, str]] = ...,
-        recursiveFileLookup: Optional[Union[bool, str]] = ...,
-        unescapedQuoteHandling: Optional[str] = ...,
-    ) -> DataFrame: ...
-    def table(self, tableName: str) -> DataFrame: ...
-
-class DataStreamWriter:
-    def __init__(self, df: DataFrame) -> None: ...
-    def outputMode(self, outputMode: str) -> DataStreamWriter: ...
-    def format(self, source: str) -> DataStreamWriter: ...
-    def option(self, key: str, value: OptionalPrimitiveType) -> DataStreamWriter: ...
-    def options(self, **options: OptionalPrimitiveType) -> DataStreamWriter: ...
-    @overload
-    def partitionBy(self, *cols: str) -> DataStreamWriter: ...
-    @overload
-    def partitionBy(self, __cols: List[str]) -> DataStreamWriter: ...
-    def queryName(self, queryName: str) -> DataStreamWriter: ...
-    @overload
-    def trigger(self, processingTime: str) -> DataStreamWriter: ...
-    @overload
-    def trigger(self, once: bool) -> DataStreamWriter: ...
-    @overload
-    def trigger(self, continuous: bool) -> DataStreamWriter: ...
-    def start(
-        self,
-        path: Optional[str] = ...,
-        format: Optional[str] = ...,
-        outputMode: Optional[str] = ...,
-        partitionBy: Optional[Union[str, List[str]]] = ...,
-        queryName: Optional[str] = ...,
-        **options: OptionalPrimitiveType
-    ) -> StreamingQuery: ...
-    @overload
-    def foreach(self, f: Callable[[Row], None]) -> DataStreamWriter: ...
-    @overload
-    def foreach(self, f: SupportsProcess) -> DataStreamWriter: ...
-    def foreachBatch(
-        self, func: Callable[[DataFrame, int], None]
-    ) -> DataStreamWriter: ...
-    def toTable(
-        self,
-        tableName: str,
-        format: Optional[str] = ...,
-        outputMode: Optional[str] = ...,
-        partitionBy: Optional[Union[str, List[str]]] = ...,
-        queryName: Optional[str] = ...,
-        **options: OptionalPrimitiveType
-    ) -> StreamingQuery: ...

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org