This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 64eabb6  [SPARK-37048][PYTHON] Clean up inlining type hints under SQL module
64eabb6 is described below

commit 64eabb6292baaaf18ee4e31cb48b204ef64aa488
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Oct 20 16:17:01 2021 -0700

    [SPARK-37048][PYTHON] Clean up inlining type hints under SQL module

    ### What changes were proposed in this pull request?

    Cleans up inlining type hints under the SQL module.

    ### Why are the changes needed?

    Now that most of the type hints under the SQL module are inlined, we should clean up the module.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    `lint-python` and existing tests should pass.

    Closes #34318 from ueshin/issues/SPARK-37048/cleanup.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/pandas/data_type_ops/base.py |  2 +-
 python/pyspark/pandas/frame.py              |  2 +-
 python/pyspark/pandas/generic.py            |  4 ++--
 python/pyspark/pandas/spark/functions.py    |  2 +-
 python/pyspark/pandas/window.py             | 21 ++++++---------------
 python/pyspark/sql/avro/functions.py        |  2 +-
 python/pyspark/sql/catalog.py               | 10 +++++-----
 python/pyspark/sql/column.py                |  2 +-
 python/pyspark/sql/conf.py                  |  2 +-
 python/pyspark/sql/context.py               |  9 ++++-----
 python/pyspark/sql/dataframe.py             | 28 +++++++++-------------------
 python/pyspark/sql/functions.py             |  4 ++--
 python/pyspark/sql/group.py                 | 13 +++----------
 python/pyspark/sql/observation.py           | 18 ++++++++----------
 python/pyspark/sql/pandas/conversion.py     |  6 +++---
 python/pyspark/sql/pandas/group_ops.py      |  4 ++--
 python/pyspark/sql/readwriter.py            | 27 +++++++-------------------
 python/pyspark/sql/session.py               | 16 ++++++++--------
 python/pyspark/sql/streaming.py             | 24 +++++++++++-------------
 python/pyspark/sql/tests/test_functions.py  |  2 +-
 python/pyspark/sql/types.py                 | 20 +++++++++++++-------
 python/pyspark/sql/udf.py                   |  6 +++---
 python/pyspark/sql/utils.py                 |  8 +++++---
 python/pyspark/sql/window.py                |  2 +-
 24 files changed, 99 insertions(+), 135 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 47a6671..9a26d18 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -395,7 +395,7 @@ class DataTypeOps(object, metaclass=ABCMeta):
         collected_structed_scol = F.collect_list(structed_scol)
         # Sort the array by NATURAL_ORDER_COLUMN so that we can guarantee the order.
         collected_structed_scol = F.array_sort(collected_structed_scol)
-        right_values_scol = F.array([F.lit(x) for x in right])  # type: ignore
+        right_values_scol = F.array(*(F.lit(x) for x in right))
         index_scol_names = left._internal.index_spark_column_names
         scol_name = left._internal.spark_column_name_for(left._internal.column_labels[0])
         # Compare the values of left and right by using zip_with function.
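The first hunk above is representative of many changes in this commit: with the
type hints now inlined, code that previously needed a `# type: ignore` can be
rewritten into a form the annotations accept directly. A minimal, standalone
sketch of the two spellings, where `values` is a made-up stand-in for `right`
in data_type_ops/base.py:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    values = [1, 2, 3]

    # Old spelling: pass a single list of Columns (needed an ignore comment).
    arr_from_list = F.array([F.lit(x) for x in values])
    # New spelling: unpack the Columns as varargs, as the patch does.
    arr_from_varargs = F.array(*(F.lit(x) for x in values))

    # Both build the same array column at runtime; only static typing differs.
    spark.range(1).select(arr_from_list, arr_from_varargs).show()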
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index c22e077..1eb91c3 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -6415,7 +6415,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         4  1   True  1.0
         5  2  False  2.0
         """
-        from pyspark.sql.types import _parse_datatype_string  # type: ignore[attr-defined]
+        from pyspark.sql.types import _parse_datatype_string

         include_list: List[str]
         if not is_list_like(include):
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 6d9379e..03019ac 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -874,7 +874,7 @@ class Frame(object, metaclass=ABCMeta):
         builder = sdf.write.mode(mode)
         if partition_cols is not None:
             builder.partitionBy(partition_cols)
-        builder._set_opts(  # type: ignore[attr-defined]
+        builder._set_opts(
             sep=sep,
             nullValue=na_rep,
             header=header,
@@ -1022,7 +1022,7 @@ class Frame(object, metaclass=ABCMeta):
         builder = sdf.write.mode(mode)
         if partition_cols is not None:
             builder.partitionBy(partition_cols)
-        builder._set_opts(compression=compression)  # type: ignore[attr-defined]
+        builder._set_opts(compression=compression)
         builder.options(**options).format("json").save(path)
         return None
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 5e171c3..73251fc 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -23,7 +23,7 @@ import numpy as np

 from pyspark import SparkContext
 from pyspark.sql import functions as F
-from pyspark.sql.column import (  # type: ignore[attr-defined]
+from pyspark.sql.column import (
     Column,
     _to_java_column,
     _create_column_from_literal,
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index ff18cf2..122cde6 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -16,14 +16,7 @@
 #
 from abc import ABCMeta, abstractmethod
 from functools import partial
-from typing import (
-    Any,
-    Callable,
-    Generic,
-    List,
-    Optional,
-    cast,
-)
+from typing import Any, Callable, Generic, List, Optional

 from pyspark.sql import Window
 from pyspark.sql import functions as F
@@ -162,12 +155,13 @@ class Rolling(RollingLike[FrameLike]):

         super().__init__(window, min_periods)

+        self._psdf_or_psser = psdf_or_psser
+
         if not isinstance(psdf_or_psser, (DataFrame, Series)):
             raise TypeError(
                 "psdf_or_psser must be a series or dataframe; however, got: %s"
                 % type(psdf_or_psser)
             )
-        self._psdf_or_psser = psdf_or_psser

     def __getattr__(self, item: str) -> Any:
         if hasattr(MissingPandasLikeRolling, item):
@@ -179,12 +173,9 @@ class Rolling(RollingLike[FrameLike]):
         raise AttributeError(item)

     def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> FrameLike:
-        return cast(
-            FrameLike,
-            self._psdf_or_psser._apply_series_op(
-                lambda psser: psser._with_new_scol(func(psser.spark.column)),  # TODO: dtype?
-                should_resolve=True,
-            ),
+        return self._psdf_or_psser._apply_series_op(
+            lambda psser: psser._with_new_scol(func(psser.spark.column)),  # TODO: dtype?
+            should_resolve=True,
         )

     def count(self) -> FrameLike:
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index 27d2887..28c571b 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -22,7 +22,7 @@ A collections of builtin avro functions
 from typing import Dict, Optional, TYPE_CHECKING

 from pyspark import SparkContext
-from pyspark.sql.column import Column, _to_java_column  # type: ignore[attr-defined]
+from pyspark.sql.column import Column, _to_java_column
 from pyspark.util import _print_missing_jar  # type: ignore[attr-defined]

 if TYPE_CHECKING:
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 29f22e4..8a2c02e 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -68,8 +68,8 @@ class Catalog(object):
     def __init__(self, sparkSession: SparkSession) -> None:
         """Create a new Catalog that wraps the underlying JVM object."""
         self._sparkSession = sparkSession
-        self._jsparkSession = sparkSession._jsparkSession  # type: ignore[attr-defined]
-        self._jcatalog = sparkSession._jsparkSession.catalog()  # type: ignore[attr-defined]
+        self._jsparkSession = sparkSession._jsparkSession
+        self._jcatalog = sparkSession._jsparkSession.catalog()

     @since(2.0)
     def currentDatabase(self) -> str:
@@ -338,10 +338,10 @@ class Catalog(object):
             options["path"] = path
         if source is None:
             source = (
-                self._sparkSession  # type: ignore[attr-defined]
+                self._sparkSession
                 ._wrapped
                 ._conf
-                .defaultDataSourceName()
+                .defaultDataSourceName()  # type: ignore[attr-defined]
             )
         if description is None:
             description = ""
@@ -353,7 +353,7 @@ class Catalog(object):
             scala_datatype = self._jsparkSession.parseDataType(schema.json())
             df = self._jcatalog.createTable(
                 tableName, source, scala_datatype, description, options)
-        return DataFrame(df, self._sparkSession._wrapped)  # type: ignore[attr-defined]
+        return DataFrame(df, self._sparkSession._wrapped)

     def dropTempView(self, viewName: str) -> None:
         """Drops the local temporary view with the given view name in the catalog.
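The window.py hunk above removes a `cast(FrameLike, ...)` wrapper rather than an
ignore comment. A standalone sketch of why that works, with a hypothetical
`Pipeline` class standing in for `Rolling[FrameLike]` and `apply` standing in
for `_apply_series_op`:

    from typing import Callable, Generic, List, TypeVar

    T = TypeVar("T")

    class Pipeline(Generic[T]):
        """Hypothetical generic container, not Spark code."""

        def __init__(self, value: T) -> None:
            self._value = value

        # Once a helper is annotated to return the class's own type variable,
        # callers get a precise type back and no longer need cast(T, ...).
        def apply(self, func: Callable[[T], T]) -> T:
            return func(self._value)

    p = Pipeline([1, 2, 3])
    reversed_list: List[int] = p.apply(lambda xs: xs[::-1])  # no cast needed
    print(reversed_list)  # [3, 2, 1]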
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index a3e3e9e..2420a88 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -31,7 +31,7 @@ from typing import (
     Union
 )

-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject  # type: ignore[import]

 from pyspark import copy_func
 from pyspark.context import SparkContext
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 54ae6fb..e4b441a 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -18,7 +18,7 @@ import sys

 from typing import Any, Optional

-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject  # type: ignore[import]

 from pyspark import since, _NoValue  # type: ignore[attr-defined]
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index e6e4528..7d27c55 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -31,10 +31,10 @@ from typing import (
     TYPE_CHECKING,
     cast
 )
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject  # type: ignore[import]

 from pyspark import since, _NoValue  # type: ignore[attr-defined]
-from pyspark.sql.session import _monkey_patch_RDD, SparkSession  # type: ignore[attr-defined]
+from pyspark.sql.session import _monkey_patch_RDD, SparkSession
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
@@ -121,7 +121,7 @@ class SQLContext(object):
         if sparkSession is None:
             sparkSession = SparkSession.builder.getOrCreate()
         if jsqlContext is None:
-            jsqlContext = sparkSession._jwrapped  # type: ignore[attr-defined]
+            jsqlContext = sparkSession._jwrapped
         self.sparkSession = sparkSession
         self._jsqlContext = jsqlContext
         _monkey_patch_RDD(self.sparkSession)
@@ -727,8 +727,7 @@ class HiveContext(SQLContext):
         if jhiveContext is None:
             sparkContext._conf.set(  # type: ignore[attr-defined]
                 "spark.sql.catalogImplementation", "hive")
-            sparkSession = SparkSession.builder._sparkContext(  # type: ignore[attr-defined]
-                sparkContext).getOrCreate()
+            sparkSession = SparkSession.builder._sparkContext(sparkContext).getOrCreate()
         else:
             sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession())
         SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 90311c4..269d1e6 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -38,13 +38,12 @@ from pyspark.serializers import BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.sql.types import _parse_datatype_json_string  # type: ignore[attr-defined]
-from pyspark.sql.column import (  # type: ignore[attr-defined]
-    Column, _to_seq, _to_list, _to_java_column
-)
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.streaming import DataStreamWriter
-from pyspark.sql.types import StructType, StructField, StringType, IntegerType, Row
+from pyspark.sql.types import (
+    StructType, StructField, StringType, IntegerType, Row, _parse_datatype_json_string
+)
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin

@@ -90,10 +89,7 @@ class DataFrame(PandasMapOpsMixin,
                 PandasConversionMixin):

     def __init__(self, jdf: JavaObject, sql_ctx: "SQLContext"):
         self._jdf = jdf
         self.sql_ctx = sql_ctx
-        self._sc: SparkContext = cast(
-            SparkContext,
-            sql_ctx and sql_ctx._sc  # type: ignore[attr-defined]
-        )
+        self._sc: SparkContext = cast(SparkContext, sql_ctx and sql_ctx._sc)
         self.is_cached = False
         # initialized lazily
         self._schema: Optional[StructType] = None
@@ -109,10 +105,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         if self._lazy_rdd is None:
             jrdd = self._jdf.javaToPython()
-            self._lazy_rdd = RDD(
-                jrdd, self.sql_ctx._sc,  # type: ignore[attr-defined]
-                BatchedSerializer(PickleSerializer())
-            )
+            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
         return self._lazy_rdd

     @property  # type: ignore[misc]
@@ -1280,10 +1273,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 raise ValueError("Weights must be positive. Found weight value: %s" % w)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         rdd_array = self._jdf.randomSplit(
-            _to_list(
-                self.sql_ctx._sc,  # type: ignore[attr-defined]
-                cast(List["ColumnOrName"], weights)
-            ),
+            _to_list(self.sql_ctx._sc, cast(List["ColumnOrName"], weights)),
             int(seed)
         )
         return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
@@ -1655,11 +1645,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         converter: Optional[Callable[..., Union["PrimitiveType", JavaObject]]] = None
     ) -> JavaObject:
         """Return a JVM Seq of Columns from a list of Column or names"""
-        return _to_seq(self.sql_ctx._sc, cols, converter)  # type: ignore[attr-defined]
+        return _to_seq(self.sql_ctx._sc, cols, converter)

     def _jmap(self, jm: Dict) -> JavaObject:
         """Return a JVM Scala Map from a dict"""
-        return _to_scala_map(self.sql_ctx._sc, jm)  # type: ignore[attr-defined]
+        return _to_scala_map(self.sql_ctx._sc, jm)

     def _jcols(self, *cols: "ColumnOrName") -> JavaObject:
         """Return a JVM Seq of Columns from a list of Column or column names
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2fd8b09..30ffac3 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -38,7 +38,7 @@ from typing import (

 from pyspark import since, SparkContext
 from pyspark.rdd import PythonEvalType
-from pyspark.sql.column import (  # type: ignore[attr-defined]
+from pyspark.sql.column import (
     Column,
     _to_java_column,
     _to_seq,
@@ -47,7 +47,7 @@ from pyspark.sql.column import (  # type: ignore[attr-defined]
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import ArrayType, DataType, StringType, StructType
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
-from pyspark.sql.udf import UserDefinedFunction, _create_udf  # type: ignore[attr-defined] # noqa: F401
+from pyspark.sql.udf import UserDefinedFunction, _create_udf  # noqa: F401
 # Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType  # noqa: F401
 from pyspark.sql.utils import to_str
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 183041f..547d022 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -21,7 +21,7 @@ from typing import Callable, List, Optional, TYPE_CHECKING, overload, Dict, Unio

 from py4j.java_gateway import JavaObject  # type: ignore[import]

-from pyspark.sql.column import Column, _to_seq  # type: ignore[attr-defined]
+from pyspark.sql.column import Column, _to_seq
 from pyspark.sql.context import SQLContext
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin
@@ -46,10 +46,7 @@ def dfapi(f: Callable) -> Callable:
 def df_varargs_api(f: Callable) -> Callable:
     def _api(self: "GroupedData", *cols: str) -> DataFrame:
         name = f.__name__
-        # TODO: ignore[attr-defined] will be removed, once SparkContext is inlined
-        jdf = getattr(self._jgd, name)(
-            _to_seq(self.sql_ctx._sc, cols)  # type: ignore[attr-defined]
-        )
+        jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, cols))
         return DataFrame(jdf, self.sql_ctx)
     _api.__name__ = f.__name__
     _api.__doc__ = f.__doc__
@@ -135,11 +132,7 @@ class GroupedData(PandasGroupedOpsMixin):
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
             exprs = cast(Tuple[Column, ...], exprs)
-            # TODO: ignore[attr-defined] will be removed, once SparkContext is inlined
-            jdf = self._jgd.agg(
-                exprs[0]._jc,
-                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])  # type: ignore[attr-defined]
-            )
+            jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
         return DataFrame(jdf, self.sql_ctx)

     @dfapi
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index b563f44..1c48f29 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, Dict, Optional
+
+from py4j.java_gateway import JavaObject, JVMView  # type: ignore[import]

 from pyspark.sql import column
 from pyspark.sql.column import Column
@@ -22,9 +24,6 @@ from pyspark.sql.dataframe import DataFrame

 __all__ = ["Observation"]

-if TYPE_CHECKING:
-    from pyspark import SparkContext  # noqa: F401
-

 class Observation:
     """Class to observe (named) metrics on a :class:`DataFrame`.
@@ -80,8 +79,8 @@ class Observation:
         if name == '':
             raise ValueError("name should not be empty")
         self._name = name
-        self._jvm = None
-        self._jo = None
+        self._jvm: Optional[JVMView] = None
+        self._jo: Optional[JavaObject] = None

     def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
         """Attaches this observation to the given :class:`DataFrame` to observe aggregations.
@@ -102,14 +101,13 @@ class Observation:
         assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
         assert self._jo is None, "an Observation can be used with a DataFrame only once"

-        self._jvm = df._sc._jvm  # type: ignore[assignment, attr-defined]
+        self._jvm = df._sc._jvm  # type: ignore[attr-defined]
         cls = self._jvm.org.apache.spark.sql.Observation  # type: ignore[attr-defined]
         self._jo = cls(self._name) if self._name is not None else cls()
-        observed_df = self._jo.on(  # type: ignore[attr-defined]
+        observed_df = self._jo.on(
             df._jdf,
             exprs[0]._jc,
-            column._to_seq(
-                df._sc, [c._jc for c in exprs[1:]])  # type: ignore[attr-defined]
+            column._to_seq(df._sc, [c._jc for c in exprs[1:]])
         )
         return DataFrame(observed_df, df.sql_ctx)
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index a9700df..5b01684 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -376,7 +376,7 @@ class SparkConversionMixin(object):
                     warnings.warn(msg)
                     raise
         converted_data = self._convert_from_pandas(data, schema, timezone)
-        return self._create_dataframe(  # type: ignore[attr-defined]
+        return self._create_dataframe(
             converted_data, schema, samplingRatio, verifySchema
         )
@@ -554,8 +554,8 @@ class SparkConversionMixin(object):
             self._jvm  # type: ignore[attr-defined]
             .PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
         )
-        df = DataFrame(jdf, self._wrapped)  # type: ignore[attr-defined]
-        df._schema = schema  # type: ignore[attr-defined]
+        df = DataFrame(jdf, self._wrapped)
+        df._schema = schema
         return df
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 84db18f..97ff32c 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -211,7 +211,7 @@ class PandasGroupedOpsMixin(object):
         udf = pandas_udf(
             func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
-        df = self._df  # type: ignore[attr-defined]
+        df = self._df
         udf_column = udf(*[df[col] for col in df.columns])
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())  # type: ignore[attr-defined]
         return DataFrame(jdf, self.sql_ctx)
@@ -352,7 +352,7 @@ class PandasCogroupedOps(object):

     @staticmethod
     def _extract_cols(gd: "GroupedData") -> List[Column]:
-        df = gd._df  # type: ignore[attr-defined]
+        df = gd._df
         return [df[col] for col in df.columns]
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 87b98ac..a31d701 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -27,7 +27,7 @@ from typing import (
     Union
 )

-from py4j.java_gateway import JavaClass, JavaObject
+from py4j.java_gateway import JavaClass, JavaObject  # type: ignore[import]

 from pyspark import RDD, since
 from pyspark.sql.column import _to_seq, _to_java_column, Column
@@ -358,8 +358,7 @@ class DataFrameReader(OptionUtils):
             modifiedAfter=modifiedAfter, datetimeRebaseMode=datetimeRebaseMode,
             int96RebaseMode=int96RebaseMode)
-        return self._df(
-            self._jreader.parquet(_to_seq(self._spark._sc, paths)))  # type: ignore[attr-defined]
+        return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths)))

     def text(
         self,
@@ -571,8 +570,7 @@ class DataFrameReader(OptionUtils):
             recursiveFileLookup=recursiveFileLookup)
         if isinstance(path, str):
             path = [path]
-        return self._df(
-            self._jreader.orc(_to_seq(self._spark._sc, path)))  # type: ignore[attr-defined]
+        return self._df(self._jreader.orc(_to_seq(self._spark._sc, path)))

     @overload
     def jdbc(
@@ -783,10 +781,7 @@ class DataFrameWriter(OptionUtils):
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]  # type: ignore[assignment]
         self._jwrite = self._jwrite.partitionBy(
-            _to_seq(
-                self._spark._sc,  # type: ignore[attr-defined]
-                cast(Iterable["ColumnOrName"], cols)
-            )
+            _to_seq(self._spark._sc, cast(Iterable["ColumnOrName"], cols))
         )
         return self
@@ -848,10 +843,7 @@ class DataFrameWriter(OptionUtils):
         self._jwrite = self._jwrite.bucketBy(
             numBuckets,
             col,
-            _to_seq(
-                self._spark._sc,  # type: ignore[attr-defined]
-                cast(Iterable["ColumnOrName"], cols)
-            )
+            _to_seq(self._spark._sc, cast(Iterable["ColumnOrName"], cols))
         )
         return self
@@ -897,11 +889,7 @@ class DataFrameWriter(OptionUtils):
                 raise TypeError("all names should be `str`")

         self._jwrite = self._jwrite.sortBy(
-            col,
-            _to_seq(
-                self._spark._sc,  # type: ignore[attr-defined]
-                cast(Iterable["ColumnOrName"], cols)
-            )
+            col, _to_seq(self._spark._sc, cast(Iterable["ColumnOrName"], cols))
         )
         return self
@@ -1385,8 +1373,7 @@ class DataFrameWriterV2(object):
         """
         col = _to_java_column(col)
-        cols = _to_seq(
-            self._spark._sc, [_to_java_column(c) for c in cols])  # type: ignore[attr-defined]
+        cols = _to_seq(self._spark._sc, [_to_java_column(c) for c in cols])
         return self

     @since(3.1)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index c8ed108..728d658 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -21,7 +21,7 @@ from functools import reduce
 from threading import RLock
 from types import TracebackType
 from typing import (
-    Any, Dict, Iterable, List, Optional, Tuple, Type, Union,
+    Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, Union,
     cast, no_type_check, overload, TYPE_CHECKING
 )
@@ -34,7 +34,7 @@ from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.pandas.conversion import SparkConversionMixin
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import (  # type: ignore[attr-defined]
+from pyspark.sql.types import (
     AtomicType, DataType, StructType,
     _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter,
     _parse_datatype_string
@@ -127,7 +127,7 @@ class SparkSession(SparkConversionMixin):
         _lock = RLock()
         _options: Dict[str, Any] = {}
-        _sc = None
+        _sc: Optional[SparkContext] = None

         @overload
         def config(self, *, conf: SparkConf) -> "SparkSession.Builder":
@@ -268,8 +268,8 @@ class SparkSession(SparkConversionMixin):
     builder = Builder()
     """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances."""

-    _instantiatedSession = None
-    _activeSession = None
+    _instantiatedSession: ClassVar[Optional["SparkSession"]] = None
+    _activeSession: ClassVar[Optional["SparkSession"]] = None

     def __init__(self, sparkContext: SparkContext, jsparkSession: Optional[JavaObject] = None):
         from pyspark.sql.context import SQLContext
@@ -292,7 +292,7 @@ class SparkSession(SparkConversionMixin):
         # which is stopped now, we need to renew the instantiated SparkSession.
         # Otherwise, we will use invalid SparkSession when we call Builder.getOrCreate.
         if SparkSession._instantiatedSession is None \
-                or SparkSession._instantiatedSession._sc._jsc is None:
+                or SparkSession._instantiatedSession._sc._jsc is None:  # type: ignore[attr-defined]
             SparkSession._instantiatedSession = self
             SparkSession._activeSession = self
             self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
@@ -845,7 +845,7 @@ class SparkSession(SparkConversionMixin):
             )
         jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), struct.json())
         df = DataFrame(jdf, self._wrapped)
-        df._schema = struct  # type: ignore[attr-defined]
+        df._schema = struct
         return df

     def sql(self, sqlQuery: str) -> DataFrame:
@@ -945,7 +945,7 @@ class SparkSession(SparkConversionMixin):
             self._jvm.SparkSession.clearActiveSession()
         SparkSession._instantiatedSession = None
         SparkSession._activeSession = None
-        SQLContext._instantiatedContext = None  # type: ignore[attr-defined]
+        SQLContext._instantiatedContext = None

     @since(2.0)
     def __enter__(self) -> "SparkSession":
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 24cd2db..94b78e7 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -20,11 +20,11 @@ import json
 from collections.abc import Iterator
 from typing import cast, overload, Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union

-from py4j.java_gateway import java_import, JavaObject
+from py4j.java_gateway import java_import, JavaObject  # type: ignore[import]

 from pyspark import since
 from pyspark.sql.column import _to_seq
-from pyspark.sql.readwriter import OptionUtils, to_str  # type: ignore[attr-defined]
+from pyspark.sql.readwriter import OptionUtils, to_str
 from pyspark.sql.types import Row, StructType, StructField, StringType
 from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException
@@ -319,7 +319,7 @@ class DataStreamReader(OptionUtils):
     """

     def __init__(self, spark: "SQLContext") -> None:
-        self._jreader = spark._ssql_ctx.readStream()  # type: ignore[attr-defined]
+        self._jreader = spark._ssql_ctx.readStream()
         self._spark = spark

     def _df(self, jdf: JavaObject) -> "DataFrame":
@@ -532,7 +532,7 @@ class DataStreamReader(OptionUtils):
        >>> json_sdf.schema == sdf_schema
        True
        """
-        self._set_opts(  # type: ignore[attr-defined]
+        self._set_opts(
             schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
             allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -576,7 +576,7 @@ class DataStreamReader(OptionUtils):
        >>> orc_sdf.schema == sdf_schema
        True
        """
-        self._set_opts(  # type: ignore[attr-defined]
+        self._set_opts(
             mergeSchema=mergeSchema,
             pathGlobFilter=pathGlobFilter,
             recursiveFileLookup=recursiveFileLookup
@@ -622,7 +622,7 @@ class DataStreamReader(OptionUtils):
        >>> parquet_sdf.schema == sdf_schema
        True
        """
-        self._set_opts(  # type: ignore[attr-defined]
+        self._set_opts(
             mergeSchema=mergeSchema,
             pathGlobFilter=pathGlobFilter,
             recursiveFileLookup=recursiveFileLookup,
@@ -678,7 +678,7 @@ class DataStreamReader(OptionUtils):
        >>> "value" in str(text_sdf.schema)
        True
        """
-        self._set_opts(  # type: ignore[attr-defined]
+        self._set_opts(
             wholetext=wholetext, lineSep=lineSep, pathGlobFilter=pathGlobFilter,
             recursiveFileLookup=recursiveFileLookup)
         if isinstance(path, str):
@@ -757,7 +757,7 @@ class DataStreamReader(OptionUtils):
        >>> csv_sdf.schema == sdf_schema
        True
        """
-        self._set_opts(  # type: ignore[attr-defined]
+        self._set_opts(
             schema=schema, sep=sep, encoding=encoding,
             quote=quote, escape=escape, comment=comment, header=header, inferSchema=inferSchema,
             ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
@@ -929,7 +929,7 @@ class DataStreamWriter(object):
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
         self._jwrite = self._jwrite.partitionBy(
-            _to_seq(self._spark._sc, cols))  # type: ignore[attr-defined]
+            _to_seq(self._spark._sc, cols))
         return self

     def queryName(self, queryName: str) -> "DataStreamWriter":
@@ -1215,9 +1215,7 @@ class DataStreamWriter(object):
             func = func_with_open_process_close  # type: ignore[assignment]

         serializer = AutoBatchedSerializer(PickleSerializer())
-        wrapped_func = _wrap_function(
-            self._spark._sc,  # type: ignore[attr-defined]
-            func, serializer, serializer)
+        wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
         jForeachWriter = (
             self._spark._sc  # type: ignore[attr-defined]
             ._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
@@ -1424,7 +1422,7 @@ def _test() -> None:
     import tempfile
     from pyspark.sql import SparkSession, SQLContext
     import pyspark.sql.streaming
-    from py4j.protocol import Py4JError
+    from py4j.protocol import Py4JError  # type: ignore[import]

     os.chdir(os.environ["SPARK_HOME"])
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 61225ab..2aecff8 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -20,7 +20,7 @@ from itertools import chain
 import re
 import math

-from py4j.protocol import Py4JJavaError
+from py4j.protocol import Py4JJavaError  # type: ignore[import]

 from pyspark.sql import Row, Window, types
 from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, \
     lit, assert_true, sum_distinct, sumDistinct, shiftleft, shiftLeft, shiftRight, \
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 6b9a723..6b40482 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -31,6 +31,7 @@ from typing import (
     overload,
     Any,
     Callable,
+    ClassVar,
     Dict,
     Iterator,
     List,
@@ -41,8 +42,8 @@ from typing import (
     TypeVar,
 )

-from py4j.protocol import register_input_converter
-from py4j.java_gateway import JavaClass, JavaGateway, JavaObject
+from py4j.protocol import register_input_converter  # type: ignore[import]
+from py4j.java_gateway import JavaClass, JavaGateway, JavaObject  # type: ignore[import]

 from pyspark.serializers import CloudPickleSerializer
@@ -112,9 +113,9 @@ class DataType(object):
 class DataTypeSingleton(type):
     """Metaclass for DataType"""

-    _instances: Dict[Type["DataTypeSingleton"], "DataTypeSingleton"] = {}
+    _instances: ClassVar[Dict[Type["DataTypeSingleton"], "DataTypeSingleton"]] = {}

-    def __call__(cls: Type[T]) -> T:  # type: ignore[override, attr-defined]
+    def __call__(cls: Type[T]) -> T:  # type: ignore[override]
         if cls not in cls._instances:  # type: ignore[attr-defined]
             cls._instances[cls] = super(  # type: ignore[misc, attr-defined]
                 DataTypeSingleton, cls).__call__()
@@ -843,8 +844,13 @@ _atomic_types: List[Type[DataType]] = [
     ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, TimestampNTZType,
     NullType]
 _all_atomic_types: Dict[str, Type[DataType]] = dict((t.typeName(), t) for t in _atomic_types)
-_complex_types: List[Type[DataType]] = [ArrayType, MapType, StructType]
-_all_complex_types: Dict[str, Type[DataType]] = dict((v.typeName(), v) for v in _complex_types)
+
+_complex_types: List[Type[Union[ArrayType, MapType, StructType]]] = [
+    ArrayType, MapType, StructType
+]
+_all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dict(
+    (v.typeName(), v) for v in _complex_types
+)

 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
@@ -987,7 +993,7 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
     else:
         tpe = json_value["type"]
         if tpe in _all_complex_types:
-            return _all_complex_types[tpe].fromJson(json_value)  # type: ignore[attr-defined]
+            return _all_complex_types[tpe].fromJson(json_value)
         elif tpe == 'udt':
             return UserDefinedType.fromJson(json_value)
         else:
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 752ccca..b4b5218 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -21,12 +21,12 @@ import functools
 import sys
 from typing import Callable, Any, TYPE_CHECKING, Optional, cast, Union

-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject  # type: ignore[import]

 from pyspark import SparkContext
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType  # type: ignore[attr-defined]
 from pyspark.sql.column import Column, _to_java_column, _to_seq
-from pyspark.sql.types import (  # type: ignore[attr-defined]
+from pyspark.sql.types import (
     StringType,
     DataType,
     StructType,
@@ -273,7 +273,7 @@ class UDFRegistration(object):
     def register(
         self,
         name: str,
-        f: "Union[Callable[..., Any], UserDefinedFunctionLike]",
+        f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
         returnType: Optional["DataTypeOrString"] = None,
     ) -> "UserDefinedFunctionLike":
         """Register a Python function (including lambda function) or a user-defined function
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index f5ade81..2be02b1 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -17,9 +17,11 @@
 from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast

 import py4j
-from py4j.java_collections import JavaArray
-from py4j.java_gateway import JavaClass, JavaGateway, JavaObject, is_instance_of
-from py4j.protocol import Py4JJavaError
+from py4j.java_collections import JavaArray  # type: ignore[import]
+from py4j.java_gateway import (  # type: ignore[import]
+    JavaClass, JavaGateway, JavaObject, is_instance_of
+)
+from py4j.protocol import Py4JJavaError  # type: ignore[import]

 from pyspark import SparkContext
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index f1b03ab..97d531f 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -19,7 +19,7 @@ import sys
 from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union

 from pyspark import since, SparkContext
-from pyspark.sql.column import _to_seq, _to_java_column  # type: ignore[attr-defined]
+from pyspark.sql.column import _to_seq, _to_java_column

 from py4j.java_gateway import JavaObject  # type: ignore[import]
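The types.py hunks above widen `_all_complex_types` from `Dict[str, Type[DataType]]`
to `Dict[str, Type[Union[ArrayType, MapType, StructType]]]`, which is what lets
`_parse_datatype_json_value` call `fromJson` without an ignore: `fromJson` is
defined on those three classes, not on the `DataType` base. A minimal sketch of
that registry pattern, with hypothetical `ArrayLike`/`MapLike` classes in place
of the real data types:

    from typing import Dict, Type, Union

    class ArrayLike:
        @classmethod
        def fromJson(cls, json: dict) -> "ArrayLike":
            return cls()

    class MapLike:
        @classmethod
        def fromJson(cls, json: dict) -> "MapLike":
            return cls()

    _registry: Dict[str, Type[Union[ArrayLike, MapLike]]] = {
        "array": ArrayLike,
        "map": MapLike,
    }

    def parse(json_value: dict) -> Union[ArrayLike, MapLike]:
        # mypy sees fromJson on every member of the Union, so no ignore is needed.
        return _registry[json_value["type"]].fromJson(json_value)

    print(type(parse({"type": "map"})).__name__)  # MapLike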