This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 64eabb6  [SPARK-37048][PYTHON] Clean up inlining type hints under SQL module
64eabb6 is described below

commit 64eabb6292baaaf18ee4e31cb48b204ef64aa488
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Oct 20 16:17:01 2021 -0700

    [SPARK-37048][PYTHON] Clean up inlining type hints under SQL module
    
    ### What changes were proposed in this pull request?
    
    Cleans up inlining type hints under SQL module.
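
    The bulk of the diff follows a couple of recurring patterns: dropping `# type: ignore[attr-defined]` comments that are no longer needed now that the referenced modules carry inline hints, and annotating class-level state explicitly instead of leaving it for mypy to infer. A rough sketch of the second pattern (the class name below is illustrative, not one of the files touched here):

    ```python
    from typing import ClassVar, Optional

    from pyspark import SparkContext


    class Example:  # illustrative class, not part of this patch
        # Before: `_sc = None`, which mypy infers as plain `None`;
        # after: an explicit, nullable class-level attribute.
        _sc: ClassVar[Optional[SparkContext]] = None
    ```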
    
    ### Why are the changes needed?
    
    Now that most of the type hints under the SQL module are inlined, we should clean up the module.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    `lint-python` and existing tests should pass.
    
    Closes #34318 from ueshin/issues/SPARK-37048/cleanup.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/pandas/data_type_ops/base.py |  2 +-
 python/pyspark/pandas/frame.py              |  2 +-
 python/pyspark/pandas/generic.py            |  4 ++--
 python/pyspark/pandas/spark/functions.py    |  2 +-
 python/pyspark/pandas/window.py             | 21 ++++++---------------
 python/pyspark/sql/avro/functions.py        |  2 +-
 python/pyspark/sql/catalog.py               | 10 +++++-----
 python/pyspark/sql/column.py                |  2 +-
 python/pyspark/sql/conf.py                  |  2 +-
 python/pyspark/sql/context.py               |  9 ++++-----
 python/pyspark/sql/dataframe.py             | 28 +++++++++-------------------
 python/pyspark/sql/functions.py             |  4 ++--
 python/pyspark/sql/group.py                 | 13 +++----------
 python/pyspark/sql/observation.py           | 18 ++++++++----------
 python/pyspark/sql/pandas/conversion.py     |  6 +++---
 python/pyspark/sql/pandas/group_ops.py      |  4 ++--
 python/pyspark/sql/readwriter.py            | 27 +++++++--------------------
 python/pyspark/sql/session.py               | 16 ++++++++--------
 python/pyspark/sql/streaming.py             | 24 +++++++++++-------------
 python/pyspark/sql/tests/test_functions.py  |  2 +-
 python/pyspark/sql/types.py                 | 20 +++++++++++++-------
 python/pyspark/sql/udf.py                   |  6 +++---
 python/pyspark/sql/utils.py                 |  8 +++++---
 python/pyspark/sql/window.py                |  2 +-
 24 files changed, 99 insertions(+), 135 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 47a6671..9a26d18 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -395,7 +395,7 @@ class DataTypeOps(object, metaclass=ABCMeta):
             collected_structed_scol = F.collect_list(structed_scol)
             # Sort the array by NATURAL_ORDER_COLUMN so that we can guarantee the order.
             collected_structed_scol = F.array_sort(collected_structed_scol)
-            right_values_scol = F.array([F.lit(x) for x in right])  # type: ignore
+            right_values_scol = F.array(*(F.lit(x) for x in right))
             index_scol_names = left._internal.index_spark_column_names
             scol_name = left._internal.spark_column_name_for(left._internal.column_labels[0])
             # Compare the values of left and right by using zip_with function.
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index c22e077..1eb91c3 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -6415,7 +6415,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         4  1   True  1.0
         5  2  False  2.0
         """
-        from pyspark.sql.types import _parse_datatype_string  # type: ignore[attr-defined]
+        from pyspark.sql.types import _parse_datatype_string
 
         include_list: List[str]
         if not is_list_like(include):
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 6d9379e..03019ac 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -874,7 +874,7 @@ class Frame(object, metaclass=ABCMeta):
         builder = sdf.write.mode(mode)
         if partition_cols is not None:
             builder.partitionBy(partition_cols)
-        builder._set_opts(  # type: ignore[attr-defined]
+        builder._set_opts(
             sep=sep,
             nullValue=na_rep,
             header=header,
@@ -1022,7 +1022,7 @@ class Frame(object, metaclass=ABCMeta):
         builder = sdf.write.mode(mode)
         if partition_cols is not None:
             builder.partitionBy(partition_cols)
-        builder._set_opts(compression=compression)  # type: ignore[attr-defined]
+        builder._set_opts(compression=compression)
         builder.options(**options).format("json").save(path)
         return None
 
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 5e171c3..73251fc 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -23,7 +23,7 @@ import numpy as np
 
 from pyspark import SparkContext
 from pyspark.sql import functions as F
-from pyspark.sql.column import (  # type: ignore[attr-defined]
+from pyspark.sql.column import (
     Column,
     _to_java_column,
     _create_column_from_literal,
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index ff18cf2..122cde6 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -16,14 +16,7 @@
 #
 from abc import ABCMeta, abstractmethod
 from functools import partial
-from typing import (
-    Any,
-    Callable,
-    Generic,
-    List,
-    Optional,
-    cast,
-)
+from typing import Any, Callable, Generic, List, Optional
 
 from pyspark.sql import Window
 from pyspark.sql import functions as F
@@ -162,12 +155,13 @@ class Rolling(RollingLike[FrameLike]):
 
         super().__init__(window, min_periods)
 
+        self._psdf_or_psser = psdf_or_psser
+
         if not isinstance(psdf_or_psser, (DataFrame, Series)):
             raise TypeError(
                 "psdf_or_psser must be a series or dataframe; however, got: %s"
                 % type(psdf_or_psser)
             )
-        self._psdf_or_psser = psdf_or_psser
 
     def __getattr__(self, item: str) -> Any:
         if hasattr(MissingPandasLikeRolling, item):
@@ -179,12 +173,9 @@ class Rolling(RollingLike[FrameLike]):
         raise AttributeError(item)
 
     def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> FrameLike:
-        return cast(
-            FrameLike,
-            self._psdf_or_psser._apply_series_op(
-                lambda psser: psser._with_new_scol(func(psser.spark.column)),  # TODO: dtype?
-                should_resolve=True,
-            ),
+        return self._psdf_or_psser._apply_series_op(
+            lambda psser: psser._with_new_scol(func(psser.spark.column)),  # TODO: dtype?
+            should_resolve=True,
         )
 
     def count(self) -> FrameLike:
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index 27d2887..28c571b 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -22,7 +22,7 @@ A collections of builtin avro functions
 
 from typing import Dict, Optional, TYPE_CHECKING
 from pyspark import SparkContext
-from pyspark.sql.column import Column, _to_java_column  # type: ignore[attr-defined]
+from pyspark.sql.column import Column, _to_java_column
 from pyspark.util import _print_missing_jar  # type: ignore[attr-defined]
 
 if TYPE_CHECKING:
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 29f22e4..8a2c02e 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -68,8 +68,8 @@ class Catalog(object):
     def __init__(self, sparkSession: SparkSession) -> None:
         """Create a new Catalog that wraps the underlying JVM object."""
         self._sparkSession = sparkSession
-        self._jsparkSession = sparkSession._jsparkSession  # type: ignore[attr-defined]
-        self._jcatalog = sparkSession._jsparkSession.catalog()  # type: ignore[attr-defined]
+        self._jsparkSession = sparkSession._jsparkSession
+        self._jcatalog = sparkSession._jsparkSession.catalog()
 
     @since(2.0)
     def currentDatabase(self) -> str:
@@ -338,10 +338,10 @@ class Catalog(object):
             options["path"] = path
         if source is None:
             source = (
-                self._sparkSession  # type: ignore[attr-defined]
+                self._sparkSession
                 ._wrapped
                 ._conf
-                .defaultDataSourceName()
+                .defaultDataSourceName()  # type: ignore[attr-defined]
             )
         if description is None:
             description = ""
@@ -353,7 +353,7 @@ class Catalog(object):
             scala_datatype = self._jsparkSession.parseDataType(schema.json())
             df = self._jcatalog.createTable(
                 tableName, source, scala_datatype, description, options)
-        return DataFrame(df, self._sparkSession._wrapped)  # type: ignore[attr-defined]
+        return DataFrame(df, self._sparkSession._wrapped)
 
     def dropTempView(self, viewName: str) -> None:
         """Drops the local temporary view with the given view name in the 
catalog.
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index a3e3e9e..2420a88 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -31,7 +31,7 @@ from typing import (
     Union
 )
 
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject  # type: ignore[import]
 
 from pyspark import copy_func
 from pyspark.context import SparkContext
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 54ae6fb..e4b441a 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -18,7 +18,7 @@
 import sys
 from typing import Any, Optional
 
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject  # type: ignore[import]
 
 from pyspark import since, _NoValue  # type: ignore[attr-defined]
 
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index e6e4528..7d27c55 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -31,10 +31,10 @@ from typing import (
     TYPE_CHECKING, cast
 )
 
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject  # type: ignore[import]
 
 from pyspark import since, _NoValue  # type: ignore[attr-defined]
-from pyspark.sql.session import _monkey_patch_RDD, SparkSession  # type: ignore[attr-defined]
+from pyspark.sql.session import _monkey_patch_RDD, SparkSession
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
@@ -121,7 +121,7 @@ class SQLContext(object):
         if sparkSession is None:
             sparkSession = SparkSession.builder.getOrCreate()
         if jsqlContext is None:
-            jsqlContext = sparkSession._jwrapped  # type: ignore[attr-defined]
+            jsqlContext = sparkSession._jwrapped
         self.sparkSession = sparkSession
         self._jsqlContext = jsqlContext
         _monkey_patch_RDD(self.sparkSession)
@@ -727,8 +727,7 @@ class HiveContext(SQLContext):
         if jhiveContext is None:
             sparkContext._conf.set(  # type: ignore[attr-defined]
                 "spark.sql.catalogImplementation", "hive")
-            sparkSession = SparkSession.builder._sparkContext(  # type: ignore[attr-defined]
-                sparkContext).getOrCreate()
+            sparkSession = SparkSession.builder._sparkContext(sparkContext).getOrCreate()
         else:
             sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession())
         SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 90311c4..269d1e6 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -38,13 +38,12 @@ from pyspark.serializers import BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.sql.types import _parse_datatype_json_string  # type: ignore[attr-defined]
-from pyspark.sql.column import (  # type: ignore[attr-defined]
-    Column, _to_seq, _to_list, _to_java_column
-)
+from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.streaming import DataStreamWriter
-from pyspark.sql.types import StructType, StructField, StringType, IntegerType, Row
+from pyspark.sql.types import (
+    StructType, StructField, StringType, IntegerType, Row, _parse_datatype_json_string
+)
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
 
@@ -90,10 +89,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
     def __init__(self, jdf: JavaObject, sql_ctx: "SQLContext"):
         self._jdf = jdf
         self.sql_ctx = sql_ctx
-        self._sc: SparkContext = cast(
-            SparkContext,
-            sql_ctx and sql_ctx._sc  # type: ignore[attr-defined]
-        )
+        self._sc: SparkContext = cast(SparkContext, sql_ctx and sql_ctx._sc)
         self.is_cached = False
         # initialized lazily
         self._schema: Optional[StructType] = None
@@ -109,10 +105,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         if self._lazy_rdd is None:
             jrdd = self._jdf.javaToPython()
-            self._lazy_rdd = RDD(
-                jrdd, self.sql_ctx._sc,  # type: ignore[attr-defined]
-                BatchedSerializer(PickleSerializer())
-            )
+            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
         return self._lazy_rdd
 
     @property  # type: ignore[misc]
@@ -1280,10 +1273,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 raise ValueError("Weights must be positive. Found weight 
value: %s" % w)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         rdd_array = self._jdf.randomSplit(
-            _to_list(
-                self.sql_ctx._sc,  # type: ignore[attr-defined]
-                cast(List["ColumnOrName"], weights)
-            ),
+            _to_list(self.sql_ctx._sc, cast(List["ColumnOrName"], weights)),
             int(seed)
         )
         return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
@@ -1655,11 +1645,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         converter: Optional[Callable[..., Union["PrimitiveType", JavaObject]]] = None
     ) -> JavaObject:
         """Return a JVM Seq of Columns from a list of Column or names"""
-        return _to_seq(self.sql_ctx._sc, cols, converter)  # type: ignore[attr-defined]
+        return _to_seq(self.sql_ctx._sc, cols, converter)
 
     def _jmap(self, jm: Dict) -> JavaObject:
         """Return a JVM Scala Map from a dict"""
-        return _to_scala_map(self.sql_ctx._sc, jm)  # type: ignore[attr-defined]
+        return _to_scala_map(self.sql_ctx._sc, jm)
 
     def _jcols(self, *cols: "ColumnOrName") -> JavaObject:
         """Return a JVM Seq of Columns from a list of Column or column names
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2fd8b09..30ffac3 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -38,7 +38,7 @@ from typing import (
 
 from pyspark import since, SparkContext
 from pyspark.rdd import PythonEvalType
-from pyspark.sql.column import (  # type: ignore[attr-defined]
+from pyspark.sql.column import (
     Column,
     _to_java_column,
     _to_seq,
@@ -47,7 +47,7 @@ from pyspark.sql.column import (  # type: ignore[attr-defined]
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import ArrayType, DataType, StringType, StructType
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
-from pyspark.sql.udf import UserDefinedFunction, _create_udf  # type: ignore[attr-defined] # noqa: F401
+from pyspark.sql.udf import UserDefinedFunction, _create_udf  # noqa: F401
 # Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType  # noqa: F401
 from pyspark.sql.utils import to_str
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 183041f..547d022 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -21,7 +21,7 @@ from typing import Callable, List, Optional, TYPE_CHECKING, overload, Dict, Unio
 
 from py4j.java_gateway import JavaObject  # type: ignore[import]
 
-from pyspark.sql.column import Column, _to_seq  # type: ignore[attr-defined]
+from pyspark.sql.column import Column, _to_seq
 from pyspark.sql.context import SQLContext
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin
@@ -46,10 +46,7 @@ def dfapi(f: Callable) -> Callable:
 def df_varargs_api(f: Callable) -> Callable:
     def _api(self: "GroupedData", *cols: str) -> DataFrame:
         name = f.__name__
-        # TODO: ignore[attr-defined] will be removed, once SparkContext is inlined
-        jdf = getattr(self._jgd, name)(
-            _to_seq(self.sql_ctx._sc, cols)  # type: ignore[attr-defined]
-        )
+        jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, cols))
         return DataFrame(jdf, self.sql_ctx)
     _api.__name__ = f.__name__
     _api.__doc__ = f.__doc__
@@ -135,11 +132,7 @@ class GroupedData(PandasGroupedOpsMixin):
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
             exprs = cast(Tuple[Column, ...], exprs)
-            # TODO: ignore[attr-defined] will be removed, once SparkContext is inlined
-            jdf = self._jgd.agg(
-                exprs[0]._jc,
-                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])  # type: ignore[attr-defined]
-            )
+            jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index b563f44..1c48f29 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, Dict, Optional
+
+from py4j.java_gateway import JavaObject, JVMView  # type: ignore[import]
 
 from pyspark.sql import column
 from pyspark.sql.column import Column
@@ -22,9 +24,6 @@ from pyspark.sql.dataframe import DataFrame
 
 __all__ = ["Observation"]
 
-if TYPE_CHECKING:
-    from pyspark import SparkContext  # noqa: F401
-
 
 class Observation:
     """Class to observe (named) metrics on a :class:`DataFrame`.
@@ -80,8 +79,8 @@ class Observation:
             if name == '':
                 raise ValueError("name should not be empty")
         self._name = name
-        self._jvm = None
-        self._jo = None
+        self._jvm: Optional[JVMView] = None
+        self._jo: Optional[JavaObject] = None
 
     def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
         """Attaches this observation to the given :class:`DataFrame` to 
observe aggregations.
@@ -102,14 +101,13 @@ class Observation:
         assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
         assert self._jo is None, "an Observation can be used with a DataFrame only once"
 
-        self._jvm = df._sc._jvm  # type: ignore[assignment, attr-defined]
+        self._jvm = df._sc._jvm  # type: ignore[attr-defined]
         cls = self._jvm.org.apache.spark.sql.Observation  # type: ignore[attr-defined]
         self._jo = cls(self._name) if self._name is not None else cls()
-        observed_df = self._jo.on(  # type: ignore[attr-defined]
+        observed_df = self._jo.on(
             df._jdf,
             exprs[0]._jc,
-            column._to_seq(
-                df._sc, [c._jc for c in exprs[1:]])  # type: ignore[attr-defined]
+            column._to_seq(df._sc, [c._jc for c in exprs[1:]])
         )
         return DataFrame(observed_df, df.sql_ctx)
 
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index a9700df..5b01684 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -376,7 +376,7 @@ class SparkConversionMixin(object):
                     warnings.warn(msg)
                     raise
         converted_data = self._convert_from_pandas(data, schema, timezone)
-        return self._create_dataframe(  # type: ignore[attr-defined]
+        return self._create_dataframe(
             converted_data, schema, samplingRatio, verifySchema
         )
 
@@ -554,8 +554,8 @@ class SparkConversionMixin(object):
             self._jvm  # type: ignore[attr-defined]
                 .PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
         )
-        df = DataFrame(jdf, self._wrapped)  # type: ignore[attr-defined]
-        df._schema = schema  # type: ignore[attr-defined]
+        df = DataFrame(jdf, self._wrapped)
+        df._schema = schema
         return df
 
 
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 84db18f..97ff32c 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -211,7 +211,7 @@ class PandasGroupedOpsMixin(object):
 
         udf = pandas_udf(
             func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
-        df = self._df  # type: ignore[attr-defined]
+        df = self._df
         udf_column = udf(*[df[col] for col in df.columns])
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())  # type: ignore[attr-defined]
         return DataFrame(jdf, self.sql_ctx)
@@ -352,7 +352,7 @@ class PandasCogroupedOps(object):
 
     @staticmethod
     def _extract_cols(gd: "GroupedData") -> List[Column]:
-        df = gd._df  # type: ignore[attr-defined]
+        df = gd._df
         return [df[col] for col in df.columns]
 
 
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 87b98ac..a31d701 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -27,7 +27,7 @@ from typing import (
     Union
 )
 
-from py4j.java_gateway import JavaClass, JavaObject
+from py4j.java_gateway import JavaClass, JavaObject  # type: ignore[import]
 
 from pyspark import RDD, since
 from pyspark.sql.column import _to_seq, _to_java_column, Column
@@ -358,8 +358,7 @@ class DataFrameReader(OptionUtils):
                        modifiedAfter=modifiedAfter, datetimeRebaseMode=datetimeRebaseMode,
                        int96RebaseMode=int96RebaseMode)
 
-        return self._df(
-            self._jreader.parquet(_to_seq(self._spark._sc, paths)))  # type: ignore[attr-defined]
+        return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths)))
 
     def text(
         self,
@@ -571,8 +570,7 @@ class DataFrameReader(OptionUtils):
                        recursiveFileLookup=recursiveFileLookup)
         if isinstance(path, str):
             path = [path]
-        return self._df(
-            self._jreader.orc(_to_seq(self._spark._sc, path)))  # type: ignore[attr-defined]
+        return self._df(self._jreader.orc(_to_seq(self._spark._sc, path)))
 
     @overload
     def jdbc(
@@ -783,10 +781,7 @@ class DataFrameWriter(OptionUtils):
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]  # type: ignore[assignment]
         self._jwrite = self._jwrite.partitionBy(
-            _to_seq(
-                self._spark._sc,  # type: ignore[attr-defined]
-                cast(Iterable["ColumnOrName"], cols)
-            )
+            _to_seq(self._spark._sc, cast(Iterable["ColumnOrName"], cols))
         )
         return self
 
@@ -848,10 +843,7 @@ class DataFrameWriter(OptionUtils):
         self._jwrite = self._jwrite.bucketBy(
             numBuckets,
             col,
-            _to_seq(
-                self._spark._sc,    # type: ignore[attr-defined]
-                cast(Iterable["ColumnOrName"], cols)
-            )
+            _to_seq(self._spark._sc, cast(Iterable["ColumnOrName"], cols))
         )
         return self
 
@@ -897,11 +889,7 @@ class DataFrameWriter(OptionUtils):
             raise TypeError("all names should be `str`")
 
         self._jwrite = self._jwrite.sortBy(
-            col,
-            _to_seq(
-                self._spark._sc,  # type: ignore[attr-defined]
-                cast(Iterable["ColumnOrName"], cols)
-            )
+            col, _to_seq(self._spark._sc, cast(Iterable["ColumnOrName"], cols))
         )
         return self
 
@@ -1385,8 +1373,7 @@ class DataFrameWriterV2(object):
 
         """
         col = _to_java_column(col)
-        cols = _to_seq(
-            self._spark._sc, [_to_java_column(c) for c in cols])  # type: ignore[attr-defined]
+        cols = _to_seq(self._spark._sc, [_to_java_column(c) for c in cols])
         return self
 
     @since(3.1)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index c8ed108..728d658 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -21,7 +21,7 @@ from functools import reduce
 from threading import RLock
 from types import TracebackType
 from typing import (
-    Any, Dict, Iterable, List, Optional, Tuple, Type, Union,
+    Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, Union,
     cast, no_type_check, overload, TYPE_CHECKING
 )
 
@@ -34,7 +34,7 @@ from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.pandas.conversion import SparkConversionMixin
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import (  # type: ignore[attr-defined]
+from pyspark.sql.types import (
     AtomicType, DataType, StructType,
     _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter,
     _parse_datatype_string
@@ -127,7 +127,7 @@ class SparkSession(SparkConversionMixin):
 
         _lock = RLock()
         _options: Dict[str, Any] = {}
-        _sc = None
+        _sc: Optional[SparkContext] = None
 
         @overload
         def config(self, *, conf: SparkConf) -> "SparkSession.Builder":
@@ -268,8 +268,8 @@ class SparkSession(SparkConversionMixin):
     builder = Builder()
     """A class attribute having a :class:`Builder` to construct 
:class:`SparkSession` instances."""
 
-    _instantiatedSession = None
-    _activeSession = None
+    _instantiatedSession: ClassVar[Optional["SparkSession"]] = None
+    _activeSession: ClassVar[Optional["SparkSession"]] = None
 
     def __init__(self, sparkContext: SparkContext, jsparkSession: Optional[JavaObject] = None):
         from pyspark.sql.context import SQLContext
@@ -292,7 +292,7 @@ class SparkSession(SparkConversionMixin):
         # which is stopped now, we need to renew the instantiated SparkSession.
         # Otherwise, we will use invalid SparkSession when we call Builder.getOrCreate.
         if SparkSession._instantiatedSession is None \
-                or SparkSession._instantiatedSession._sc._jsc is None:
+                or SparkSession._instantiatedSession._sc._jsc is None:  # type: ignore[attr-defined]
             SparkSession._instantiatedSession = self
             SparkSession._activeSession = self
             self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
@@ -845,7 +845,7 @@ class SparkSession(SparkConversionMixin):
         )
         jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), struct.json())
         df = DataFrame(jdf, self._wrapped)
-        df._schema = struct  # type: ignore[attr-defined]
+        df._schema = struct
         return df
 
     def sql(self, sqlQuery: str) -> DataFrame:
@@ -945,7 +945,7 @@ class SparkSession(SparkConversionMixin):
         self._jvm.SparkSession.clearActiveSession()
         SparkSession._instantiatedSession = None
         SparkSession._activeSession = None
-        SQLContext._instantiatedContext = None  # type: ignore[attr-defined]
+        SQLContext._instantiatedContext = None
 
     @since(2.0)
     def __enter__(self) -> "SparkSession":
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 24cd2db..94b78e7 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -20,11 +20,11 @@ import json
 from collections.abc import Iterator
 from typing import cast, overload, Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
 
-from py4j.java_gateway import java_import, JavaObject
+from py4j.java_gateway import java_import, JavaObject  # type: ignore[import]
 
 from pyspark import since
 from pyspark.sql.column import _to_seq
-from pyspark.sql.readwriter import OptionUtils, to_str  # type: ignore[attr-defined]
+from pyspark.sql.readwriter import OptionUtils, to_str
 from pyspark.sql.types import Row, StructType, StructField, StringType
 from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException
 
@@ -319,7 +319,7 @@ class DataStreamReader(OptionUtils):
     """
 
     def __init__(self, spark: "SQLContext") -> None:
-        self._jreader = spark._ssql_ctx.readStream()  # type: ignore[attr-defined]
+        self._jreader = spark._ssql_ctx.readStream()
         self._spark = spark
 
     def _df(self, jdf: JavaObject) -> "DataFrame":
@@ -532,7 +532,7 @@ class DataStreamReader(OptionUtils):
         >>> json_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(  # type: ignore[attr-defined]
+        self._set_opts(
             schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
             allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -576,7 +576,7 @@ class DataStreamReader(OptionUtils):
         >>> orc_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(   # type: ignore[attr-defined]
+        self._set_opts(
             mergeSchema=mergeSchema,
             pathGlobFilter=pathGlobFilter,
             recursiveFileLookup=recursiveFileLookup
@@ -622,7 +622,7 @@ class DataStreamReader(OptionUtils):
         >>> parquet_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(  # type: ignore[attr-defined]
+        self._set_opts(
             mergeSchema=mergeSchema,
             pathGlobFilter=pathGlobFilter,
             recursiveFileLookup=recursiveFileLookup,
@@ -678,7 +678,7 @@ class DataStreamReader(OptionUtils):
         >>> "value" in str(text_sdf.schema)
         True
         """
-        self._set_opts(  # type: ignore[attr-defined]
+        self._set_opts(
             wholetext=wholetext, lineSep=lineSep, pathGlobFilter=pathGlobFilter,
             recursiveFileLookup=recursiveFileLookup)
         if isinstance(path, str):
@@ -757,7 +757,7 @@ class DataStreamReader(OptionUtils):
         >>> csv_sdf.schema == sdf_schema
         True
         """
-        self._set_opts(  # type: ignore[attr-defined]
+        self._set_opts(
             schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
@@ -929,7 +929,7 @@ class DataStreamWriter(object):
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
         self._jwrite = self._jwrite.partitionBy(
-            _to_seq(self._spark._sc, cols))  # type: ignore[attr-defined]
+            _to_seq(self._spark._sc, cols))
         return self
 
     def queryName(self, queryName: str) -> "DataStreamWriter":
@@ -1215,9 +1215,7 @@ class DataStreamWriter(object):
             func = func_with_open_process_close  # type: ignore[assignment]
 
         serializer = AutoBatchedSerializer(PickleSerializer())
-        wrapped_func = _wrap_function(
-            self._spark._sc,  # type: ignore[attr-defined]
-            func, serializer, serializer)
+        wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
         jForeachWriter = (
             self._spark._sc    # type: ignore[attr-defined]
             ._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
@@ -1424,7 +1422,7 @@ def _test() -> None:
     import tempfile
     from pyspark.sql import SparkSession, SQLContext
     import pyspark.sql.streaming
-    from py4j.protocol import Py4JError
+    from py4j.protocol import Py4JError  # type: ignore[import]
 
     os.chdir(os.environ["SPARK_HOME"])
 
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 61225ab..2aecff8 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -20,7 +20,7 @@ from itertools import chain
 import re
 import math
 
-from py4j.protocol import Py4JJavaError
+from py4j.protocol import Py4JJavaError  # type: ignore[import]
 from pyspark.sql import Row, Window, types
 from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, \
     lit, assert_true, sum_distinct, sumDistinct, shiftleft, shiftLeft, shiftRight, \
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 6b9a723..6b40482 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -31,6 +31,7 @@ from typing import (
     overload,
     Any,
     Callable,
+    ClassVar,
     Dict,
     Iterator,
     List,
@@ -41,8 +42,8 @@ from typing import (
     TypeVar,
 )
 
-from py4j.protocol import register_input_converter
-from py4j.java_gateway import JavaClass, JavaGateway, JavaObject
+from py4j.protocol import register_input_converter  # type: ignore[import]
+from py4j.java_gateway import JavaClass, JavaGateway, JavaObject  # type: ignore[import]
 
 from pyspark.serializers import CloudPickleSerializer
 
@@ -112,9 +113,9 @@ class DataType(object):
 class DataTypeSingleton(type):
     """Metaclass for DataType"""
 
-    _instances: Dict[Type["DataTypeSingleton"], "DataTypeSingleton"] = {}
+    _instances: ClassVar[Dict[Type["DataTypeSingleton"], "DataTypeSingleton"]] = {}
 
-    def __call__(cls: Type[T]) -> T:  # type: ignore[override, attr-defined]
+    def __call__(cls: Type[T]) -> T:  # type: ignore[override]
         if cls not in cls._instances:  # type: ignore[attr-defined]
             cls._instances[cls] = super(  # type: ignore[misc, attr-defined]
                 DataTypeSingleton, cls).__call__()
@@ -843,8 +844,13 @@ _atomic_types: List[Type[DataType]] = [
     ByteType, ShortType, IntegerType, LongType, DateType, TimestampType,
     TimestampNTZType, NullType]
 _all_atomic_types: Dict[str, Type[DataType]] = dict((t.typeName(), t) for t in _atomic_types)
-_complex_types: List[Type[DataType]] = [ArrayType, MapType, StructType]
-_all_complex_types: Dict[str, Type[DataType]] = dict((v.typeName(), v) for v in _complex_types)
+
+_complex_types: List[Type[Union[ArrayType, MapType, StructType]]] = [
+    ArrayType, MapType, StructType
+]
+_all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dict(
+    (v.typeName(), v) for v in _complex_types
+)
 
 
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
@@ -987,7 +993,7 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
     else:
         tpe = json_value["type"]
         if tpe in _all_complex_types:
-            return _all_complex_types[tpe].fromJson(json_value)  # type: ignore[attr-defined]
+            return _all_complex_types[tpe].fromJson(json_value)
         elif tpe == 'udt':
             return UserDefinedType.fromJson(json_value)
         else:
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 752ccca..b4b5218 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -21,12 +21,12 @@ import functools
 import sys
 from typing import Callable, Any, TYPE_CHECKING, Optional, cast, Union
 
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject  # type: ignore[import]
 
 from pyspark import SparkContext
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType  # type: ignore[attr-defined]
 from pyspark.sql.column import Column, _to_java_column, _to_seq
-from pyspark.sql.types import (  # type: ignore[attr-defined]
+from pyspark.sql.types import (
     StringType,
     DataType,
     StructType,
@@ -273,7 +273,7 @@ class UDFRegistration(object):
     def register(
         self,
         name: str,
-        f: "Union[Callable[..., Any], UserDefinedFunctionLike]",
+        f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
         returnType: Optional["DataTypeOrString"] = None,
     ) -> "UserDefinedFunctionLike":
         """Register a Python function (including lambda function) or a 
user-defined function
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index f5ade81..2be02b1 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -17,9 +17,11 @@
 from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast
 
 import py4j
-from py4j.java_collections import JavaArray
-from py4j.java_gateway import JavaClass, JavaGateway, JavaObject, is_instance_of
-from py4j.protocol import Py4JJavaError
+from py4j.java_collections import JavaArray  # type: ignore[import]
+from py4j.java_gateway import (  # type: ignore[import]
+    JavaClass, JavaGateway, JavaObject, is_instance_of
+)
+from py4j.protocol import Py4JJavaError  # type: ignore[import]
 
 from pyspark import SparkContext
 
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index f1b03ab..97d531f 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -19,7 +19,7 @@ import sys
 from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union
 
 from pyspark import since, SparkContext
-from pyspark.sql.column import _to_seq, _to_java_column  # type: ignore[attr-defined]
+from pyspark.sql.column import _to_seq, _to_java_column
 
 from py4j.java_gateway import JavaObject  # type: ignore[import]
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
