This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 722ac1b  [SPARK-36910][PYTHON] Inline type hints for python/pyspark/sql/types.py
722ac1b is described below

commit 722ac1b8b7f86fdeedf20cc11c7f547e7038029c
Author: Xinrong Meng <xinrong.m...@databricks.com>
AuthorDate: Fri Oct 15 12:07:17 2021 -0700

    [SPARK-36910][PYTHON] Inline type hints for python/pyspark/sql/types.py

    ### What changes were proposed in this pull request?
    Inline type hints for python/pyspark/sql/types.py.

    ### Why are the changes needed?
    The current stub files cannot support type checking of function bodies;
    inline type hints let the type checker verify the function bodies as well.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Existing tests.

    Closes #34174 from xinrong-databricks/inline_types.

    Authored-by: Xinrong Meng <xinrong.m...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/pandas/frame.py  |   6 +-
 python/pyspark/sql/dataframe.py |   3 +-
 python/pyspark/sql/session.py   |   2 +-
 python/pyspark/sql/types.py     | 437 +++++++++++++++++++++++++---------------
 python/pyspark/sql/types.pyi    | 210 -------------------
 5 files changed, 280 insertions(+), 378 deletions(-)
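To make the motivation concrete: a checker such as mypy trusts a stub's signatures and never analyzes the implementation behind them, so a bug inside a function body goes unnoticed; once the hint is written inline, the body itself is checked. A minimal sketch (hypothetical `label` function, not part of this patch):

    # label.pyi (stub): the checker validates callers of label(),
    # but never analyzes the implementation below.
    #     def label(n: int) -> str: ...

    # label.py, with the hint written inline instead:
    def label(n: int) -> str:
        return n + 1  # mypy now reports the bug in the body itself:
                      # error: Incompatible return value type (got "int", expected "str")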
%s" % e) from e diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 60e2d69..c8ed108 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -790,7 +790,7 @@ class SparkSession(SparkConversionMixin): raise TypeError("data is already a DataFrame") if isinstance(schema, str): - schema = _parse_datatype_string(schema) + schema = cast(Union[AtomicType, StructType, str], _parse_datatype_string(schema)) elif isinstance(schema, (list, tuple)): # Must re-encode any unicode strings to be consistent with StructField names schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ba31fc2..69ec96e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -25,13 +25,30 @@ import re import base64 from array import array import ctypes +from collections.abc import Iterable +from typing import ( + cast, + overload, + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Union, + Tuple, + Type, + TypeVar, +) from py4j.protocol import register_input_converter -from py4j.java_gateway import JavaClass +from py4j.java_gateway import JavaClass, JavaGateway, JavaObject -from pyspark import SparkContext from pyspark.serializers import CloudPickleSerializer +T = TypeVar("T") +U = TypeVar("U") + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "TimestampNTZType", "DecimalType", "DoubleType", "FloatType", @@ -42,34 +59,34 @@ __all__ = [ class DataType(object): """Base class for data types.""" - def __repr__(self): + def __repr__(self) -> str: return self.__class__.__name__ - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @classmethod - def typeName(cls): + def typeName(cls) -> str: return cls.__name__[:-4].lower() - def simpleString(self): + def simpleString(self) -> str: return self.typeName() - def jsonValue(self): + def jsonValue(self) -> Union[str, Dict[str, Any]]: return self.typeName() - def json(self): + def json(self) -> str: return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) - def needConversion(self): + def needConversion(self) -> bool: """ Does this type needs conversion between Python object and internal SQL object. @@ -77,13 +94,13 @@ class DataType(object): """ return False - def toInternal(self, obj): + def toInternal(self, obj: Any) -> Any: """ Converts a Python object into an internal SQL object. """ return obj - def fromInternal(self, obj): + def fromInternal(self, obj: Any) -> Any: """ Converts an internal SQL object into a native Python object. 
""" @@ -95,12 +112,13 @@ class DataType(object): class DataTypeSingleton(type): """Metaclass for DataType""" - _instances = {} + _instances: Dict[Type["DataTypeSingleton"], "DataTypeSingleton"] = {} - def __call__(cls): - if cls not in cls._instances: - cls._instances[cls] = super(DataTypeSingleton, cls).__call__() - return cls._instances[cls] + def __call__(cls: Type[T]) -> T: # type: ignore[override, attr-defined] + if cls not in cls._instances: # type: ignore[attr-defined] + cls._instances[cls] = super( # type: ignore[misc, attr-defined] + DataTypeSingleton, cls).__call__() + return cls._instances[cls] # type: ignore[attr-defined] class NullType(DataType, metaclass=DataTypeSingleton): @@ -109,7 +127,7 @@ class NullType(DataType, metaclass=DataTypeSingleton): The data type representing None, used for the types that cannot be inferred. """ @classmethod - def typeName(cls): + def typeName(cls) -> str: return 'void' @@ -158,14 +176,14 @@ class DateType(AtomicType, metaclass=DataTypeSingleton): EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() - def needConversion(self): + def needConversion(self) -> bool: return True - def toInternal(self, d): + def toInternal(self, d: datetime.date) -> int: if d is not None: return d.toordinal() - self.EPOCH_ORDINAL - def fromInternal(self, v): + def fromInternal(self, v: int) -> datetime.date: if v is not None: return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) @@ -174,16 +192,16 @@ class TimestampType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type. """ - def needConversion(self): + def needConversion(self) -> bool: return True - def toInternal(self, dt): + def toInternal(self, dt: datetime.datetime) -> int: if dt is not None: seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple())) return int(seconds) * 1000000 + dt.microsecond - def fromInternal(self, ts): + def fromInternal(self, ts: int) -> datetime.datetime: if ts is not None: # using int to avoid precision loss in float return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) @@ -193,19 +211,19 @@ class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information. """ - def needConversion(self): + def needConversion(self) -> bool: return True @classmethod - def typeName(cls): + def typeName(cls) -> str: return 'timestamp_ntz' - def toInternal(self, dt): + def toInternal(self, dt: datetime.datetime) -> int: if dt is not None: seconds = calendar.timegm(dt.timetuple()) return int(seconds) * 1000000 + dt.microsecond - def fromInternal(self, ts): + def fromInternal(self, ts: int) -> datetime.datetime: if ts is not None: # using int to avoid precision loss in float return datetime.datetime.utcfromtimestamp( @@ -232,18 +250,18 @@ class DecimalType(FractionalType): the number of digits on right side of dot. 


 class NullType(DataType, metaclass=DataTypeSingleton):
@@ -109,7 +127,7 @@ class NullType(DataType, metaclass=DataTypeSingleton):
     The data type representing None, used for the types that cannot be inferred.
     """

     @classmethod
-    def typeName(cls):
+    def typeName(cls) -> str:
         return 'void'

@@ -158,14 +176,14 @@ class DateType(AtomicType, metaclass=DataTypeSingleton):

     EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()

-    def needConversion(self):
+    def needConversion(self) -> bool:
         return True

-    def toInternal(self, d):
+    def toInternal(self, d: datetime.date) -> int:
         if d is not None:
             return d.toordinal() - self.EPOCH_ORDINAL

-    def fromInternal(self, v):
+    def fromInternal(self, v: int) -> datetime.date:
         if v is not None:
             return datetime.date.fromordinal(v + self.EPOCH_ORDINAL)

@@ -174,16 +192,16 @@ class TimestampType(AtomicType, metaclass=DataTypeSingleton):
     """Timestamp (datetime.datetime) data type.
     """

-    def needConversion(self):
+    def needConversion(self) -> bool:
         return True

-    def toInternal(self, dt):
+    def toInternal(self, dt: datetime.datetime) -> int:
         if dt is not None:
             seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
                        else time.mktime(dt.timetuple()))
             return int(seconds) * 1000000 + dt.microsecond

-    def fromInternal(self, ts):
+    def fromInternal(self, ts: int) -> datetime.datetime:
         if ts is not None:
             # using int to avoid precision loss in float
             return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000)

@@ -193,19 +211,19 @@ class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton):
     """Timestamp (datetime.datetime) data type without timezone information.
     """

-    def needConversion(self):
+    def needConversion(self) -> bool:
         return True

     @classmethod
-    def typeName(cls):
+    def typeName(cls) -> str:
         return 'timestamp_ntz'

-    def toInternal(self, dt):
+    def toInternal(self, dt: datetime.datetime) -> int:
         if dt is not None:
             seconds = calendar.timegm(dt.timetuple())
             return int(seconds) * 1000000 + dt.microsecond

-    def fromInternal(self, ts):
+    def fromInternal(self, ts: int) -> datetime.datetime:
         if ts is not None:
             # using int to avoid precision loss in float
             return datetime.datetime.utcfromtimestamp(
@@ -232,18 +250,18 @@ class DecimalType(FractionalType):
         the number of digits on right side of dot. (default: 0)
     """

-    def __init__(self, precision=10, scale=0):
+    def __init__(self, precision: int = 10, scale: int = 0):
         self.precision = precision
         self.scale = scale
         self.hasPrecisionInfo = True  # this is a public API

-    def simpleString(self):
+    def simpleString(self) -> str:
         return "decimal(%d,%d)" % (self.precision, self.scale)

-    def jsonValue(self):
+    def jsonValue(self) -> str:
         return "decimal(%d,%d)" % (self.precision, self.scale)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "DecimalType(%d,%d)" % (self.precision, self.scale)

@@ -262,14 +280,14 @@ class FloatType(FractionalType, metaclass=DataTypeSingleton):

 class ByteType(IntegralType):
     """Byte data type, i.e. a signed integer in a single byte.
     """
-    def simpleString(self):
+    def simpleString(self) -> str:
         return 'tinyint'


 class IntegerType(IntegralType):
     """Int data type, i.e. a signed 32-bit integer.
     """
-    def simpleString(self):
+    def simpleString(self) -> str:
         return 'int'

@@ -279,14 +297,14 @@ class LongType(IntegralType):
     If the values are beyond the range of [-9223372036854775808, 9223372036854775807],
     please use :class:`DecimalType`.
     """
-    def simpleString(self):
+    def simpleString(self) -> str:
         return 'bigint'


 class ShortType(IntegralType):
     """Short data type, i.e. a signed 16-bit integer.
     """
-    def simpleString(self):
+    def simpleString(self) -> str:
         return 'smallint'

@@ -308,38 +326,38 @@ class ArrayType(DataType):
     False
     """

-    def __init__(self, elementType, containsNull=True):
+    def __init__(self, elementType: DataType, containsNull: bool = True):
         assert isinstance(elementType, DataType),\
             "elementType %s should be an instance of %s" % (elementType, DataType)
         self.elementType = elementType
         self.containsNull = containsNull

-    def simpleString(self):
+    def simpleString(self) -> str:
         return 'array<%s>' % self.elementType.simpleString()

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "ArrayType(%s,%s)" % (self.elementType,
                                      str(self.containsNull).lower())

-    def jsonValue(self):
+    def jsonValue(self) -> Dict[str, Any]:
         return {"type": self.typeName(),
                 "elementType": self.elementType.jsonValue(),
                 "containsNull": self.containsNull}

     @classmethod
-    def fromJson(cls, json):
+    def fromJson(cls, json: Dict[str, Any]) -> "ArrayType":
         return ArrayType(_parse_datatype_json_value(json["elementType"]),
                          json["containsNull"])

-    def needConversion(self):
+    def needConversion(self) -> bool:
         return self.elementType.needConversion()

-    def toInternal(self, obj):
+    def toInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]:
         if not self.needConversion():
             return obj
         return obj and [self.elementType.toInternal(v) for v in obj]

-    def fromInternal(self, obj):
+    def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]:
         if not self.needConversion():
             return obj
         return obj and [self.elementType.fromInternal(v) for v in obj]

@@ -371,7 +389,9 @@ class MapType(DataType):
     False
     """

-    def __init__(self, keyType, valueType, valueContainsNull=True):
+    def __init__(
+        self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True
+    ):
         assert isinstance(keyType, DataType),\
             "keyType %s should be an instance of %s" % (keyType, DataType)
         assert isinstance(valueType, DataType),\
@@ -380,35 +400,35 @@ class MapType(DataType):
         self.valueType = valueType
         self.valueContainsNull = valueContainsNull

-    def simpleString(self):
+    def simpleString(self) -> str:
         return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString())

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
                                       str(self.valueContainsNull).lower())

-    def jsonValue(self):
+    def jsonValue(self) -> Dict[str, Any]:
         return {"type": self.typeName(),
                 "keyType": self.keyType.jsonValue(),
                 "valueType": self.valueType.jsonValue(),
                 "valueContainsNull": self.valueContainsNull}

     @classmethod
-    def fromJson(cls, json):
+    def fromJson(cls, json: Dict[str, Any]) -> "MapType":
         return MapType(_parse_datatype_json_value(json["keyType"]),
                        _parse_datatype_json_value(json["valueType"]),
                        json["valueContainsNull"])

-    def needConversion(self):
+    def needConversion(self) -> bool:
         return self.keyType.needConversion() or self.valueType.needConversion()

-    def toInternal(self, obj):
+    def toInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]:
         if not self.needConversion():
             return obj
         return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v))
                             for k, v in obj.items())

-    def fromInternal(self, obj):
+    def fromInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]:
         if not self.needConversion():
             return obj
         return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v))
@@ -439,7 +459,13 @@ class StructField(DataType):
     False
     """

-    def __init__(self, name, dataType, nullable=True, metadata=None):
+    def __init__(
+        self,
+        name: str,
+        dataType: DataType,
+        nullable: bool = True,
+        metadata: Optional[Dict[str, Any]] = None,
+    ):
         assert isinstance(dataType, DataType),\
             "dataType %s should be an instance of %s" % (dataType, DataType)
         assert isinstance(name, str), "field name %s should be a string" % (name)
@@ -448,36 +474,36 @@ class StructField(DataType):
         self.nullable = nullable
         self.metadata = metadata or {}

-    def simpleString(self):
+    def simpleString(self) -> str:
         return '%s:%s' % (self.name, self.dataType.simpleString())

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "StructField(%s,%s,%s)" % (self.name, self.dataType,
                                           str(self.nullable).lower())

-    def jsonValue(self):
+    def jsonValue(self) -> Dict[str, Any]:
         return {"name": self.name,
                 "type": self.dataType.jsonValue(),
                 "nullable": self.nullable,
                 "metadata": self.metadata}

     @classmethod
-    def fromJson(cls, json):
+    def fromJson(cls, json: Dict[str, Any]) -> "StructField":
         return StructField(json["name"],
                            _parse_datatype_json_value(json["type"]),
                            json["nullable"],
                            json["metadata"])

-    def needConversion(self):
+    def needConversion(self) -> bool:
         return self.dataType.needConversion()

-    def toInternal(self, obj):
+    def toInternal(self, obj: T) -> T:
         return self.dataType.toInternal(obj)

-    def fromInternal(self, obj):
+    def fromInternal(self, obj: T) -> T:
         return self.dataType.fromInternal(obj)

-    def typeName(self):
+    def typeName(self) -> str:  # type: ignore[override]
         raise TypeError(
             "StructField does not have typeName. "
             "Use typeName on its type explicitly instead.")
@@ -509,7 +535,7 @@ class StructType(DataType):
     >>> struct1 == struct2
     False
     """
-    def __init__(self, fields=None):
+    def __init__(self, fields: Optional[List[StructField]] = None):
         if not fields:
             self.fields = []
             self.names = []
@@ -522,7 +548,27 @@ class StructType(DataType):
         self._needConversion = [f.needConversion() for f in self]
         self._needSerializeAnyField = any(self._needConversion)

-    def add(self, field, data_type=None, nullable=True, metadata=None):
+    @overload
+    def add(
+        self,
+        field: str,
+        data_type: Union[str, DataType],
+        nullable: bool = True,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> "StructType":
+        ...
+
+    @overload
+    def add(self, field: StructField) -> "StructType":
+        ...
+
+    def add(
+        self,
+        field: Union[str, StructField],
+        data_type: Optional[Union[str, DataType]] = None,
+        nullable: bool = True,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> "StructType":
         """
         Construct a :class:`StructType` by adding new elements to it, to define the schema.
         The method accepts either:
@@ -581,15 +627,15 @@ class StructType(DataType):
         self._needSerializeAnyField = any(self._needConversion)
         return self

-    def __iter__(self):
+    def __iter__(self) -> Iterator[StructField]:
         """Iterate the fields"""
         return iter(self.fields)

-    def __len__(self):
+    def __len__(self) -> int:
         """Return the number of fields."""
         return len(self.fields)

-    def __getitem__(self, key):
+    def __getitem__(self, key: Union[str, int]) -> StructField:
         """Access fields by name or slice."""
         if isinstance(key, str):
             for field in self:
@@ -606,22 +652,22 @@ class StructType(DataType):
         else:
             raise TypeError('StructType keys should be strings, integers or slices')

-    def simpleString(self):
+    def simpleString(self) -> str:
         return 'struct<%s>' % (','.join(f.simpleString() for f in self))

-    def __repr__(self):
+    def __repr__(self) -> str:
         return ("StructType(List(%s))" %
                 ",".join(str(field) for field in self))

-    def jsonValue(self):
+    def jsonValue(self) -> Dict[str, Any]:
         return {"type": self.typeName(),
                 "fields": [f.jsonValue() for f in self]}

     @classmethod
-    def fromJson(cls, json):
+    def fromJson(cls, json: Dict[str, Any]) -> "StructType":
         return StructType([StructField.fromJson(f) for f in json["fields"]])

-    def fieldNames(self):
+    def fieldNames(self) -> List[str]:
         """
         Returns all field names in a list.

@@ -633,11 +679,11 @@ class StructType(DataType):
         """
         return list(self.names)

-    def needConversion(self):
+    def needConversion(self) -> bool:
         # We need convert Row()/namedtuple into tuple()
         return True

-    def toInternal(self, obj):
+    def toInternal(self, obj: Tuple) -> Tuple:
         if obj is None:
             return

@@ -666,12 +712,14 @@ class StructType(DataType):
         else:
             raise ValueError("Unexpected tuple %r with StructType" % obj)

-    def fromInternal(self, obj):
+    def fromInternal(self, obj: Tuple) -> "Row":
         if obj is None:
             return
         if isinstance(obj, Row):
             # it's already converted by pickler
             return obj
+
+        values: Union[Tuple, List]
         if self._needSerializeAnyField:
             # Only calling fromInternal function for fields that need conversion
             values = [f.fromInternal(v) if c else v
@@ -688,71 +736,71 @@ class UserDefinedType(DataType):
     """

     @classmethod
-    def typeName(cls):
+    def typeName(cls) -> str:
         return cls.__name__.lower()

     @classmethod
-    def sqlType(cls):
+    def sqlType(cls) -> DataType:
         """
         Underlying SQL storage type for this UDT.
         """
         raise NotImplementedError("UDT must implement sqlType().")

     @classmethod
-    def module(cls):
+    def module(cls) -> str:
         """
         The Python module of the UDT.
         """
         raise NotImplementedError("UDT must implement module().")

     @classmethod
-    def scalaUDT(cls):
+    def scalaUDT(cls) -> str:
         """
         The class name of the paired Scala UDT (could be '', if there
         is no corresponding one).
         """
         return ''

-    def needConversion(self):
+    def needConversion(self) -> bool:
         return True

     @classmethod
-    def _cachedSqlType(cls):
+    def _cachedSqlType(cls) -> DataType:
         """
         Cache the sqlType() into class, because it's heavily used in `toInternal`.
""" if not hasattr(cls, "_cached_sql_type"): - cls._cached_sql_type = cls.sqlType() - return cls._cached_sql_type + cls._cached_sql_type = cls.sqlType() # type: ignore[attr-defined] + return cls._cached_sql_type # type: ignore[attr-defined] - def toInternal(self, obj): + def toInternal(self, obj: Any) -> Any: if obj is not None: return self._cachedSqlType().toInternal(self.serialize(obj)) - def fromInternal(self, obj): + def fromInternal(self, obj: Any) -> Any: v = self._cachedSqlType().fromInternal(obj) if v is not None: return self.deserialize(v) - def serialize(self, obj): + def serialize(self, obj: Any) -> Any: """ Converts a user-type object into a SQL datum. """ raise NotImplementedError("UDT must implement toInternal().") - def deserialize(self, datum): + def deserialize(self, datum: Any) -> Any: """ Converts a SQL datum into a user-type object. """ raise NotImplementedError("UDT must implement fromInternal().") - def simpleString(self): + def simpleString(self) -> str: return 'udt' - def json(self): + def json(self) -> str: return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) - def jsonValue(self): + def jsonValue(self) -> Dict[str, Any]: if self.scalaUDT(): assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT' schema = { @@ -773,7 +821,7 @@ class UserDefinedType(DataType): return schema @classmethod - def fromJson(cls, json): + def fromJson(cls, json: Dict[str, Any]) -> "UserDefinedType": pyUDT = str(json["pyClass"]) # convert unicode to str split = pyUDT.rfind(".") pyModule = pyUDT[:split] @@ -786,22 +834,23 @@ class UserDefinedType(DataType): UDT = getattr(m, pyClass) return UDT() - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return type(self) == type(other) -_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, - ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, - TimestampNTZType, NullType] -_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) -_all_complex_types = dict((v.typeName(), v) - for v in [ArrayType, MapType, StructType]) +_atomic_types: List[Type[DataType]] = [ + StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, + ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, + TimestampNTZType, NullType] +_all_atomic_types: Dict[str, Type[DataType]] = dict((t.typeName(), t) for t in _atomic_types) +_complex_types: List[Type[DataType]] = [ArrayType, MapType, StructType] +_all_complex_types: Dict[str, Type[DataType]] = dict((v.typeName(), v) for v in _complex_types) _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)") -def _parse_datatype_string(s): +def _parse_datatype_string(s: str) -> DataType: """ Parses the given data type string to a :class:`DataType`. The data type string format equals :class:`DataType.simpleString`, except that the top level struct type can omit @@ -843,13 +892,15 @@ def _parse_datatype_string(s): ... ParseException:... 
""" - sc = SparkContext._active_spark_context + from pyspark import SparkContext - def from_ddl_schema(type_str): + sc = SparkContext._active_spark_context # type: ignore[attr-defined] + + def from_ddl_schema(type_str: str) -> DataType: return _parse_datatype_json_string( sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json()) - def from_ddl_datatype(type_str): + def from_ddl_datatype(type_str: str) -> DataType: return _parse_datatype_json_string( sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json()) @@ -868,7 +919,7 @@ def _parse_datatype_string(s): raise e -def _parse_datatype_json_string(json_string): +def _parse_datatype_json_string(json_string: str) -> DataType: """Parses the given data type JSON string. Examples @@ -920,7 +971,7 @@ def _parse_datatype_json_string(json_string): return _parse_datatype_json_value(json.loads(json_string)) -def _parse_datatype_json_value(json_value): +def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType: if not isinstance(json_value, dict): if json_value in _all_atomic_types.keys(): return _all_atomic_types[json_value]() @@ -930,13 +981,13 @@ def _parse_datatype_json_value(json_value): return TimestampNTZType() elif _FIXED_DECIMAL.match(json_value): m = _FIXED_DECIMAL.match(json_value) - return DecimalType(int(m.group(1)), int(m.group(2))) + return DecimalType(int(m.group(1)), int(m.group(2))) # type: ignore[union-attr] else: raise ValueError("Could not parse datatype: %s" % json_value) else: tpe = json_value["type"] if tpe in _all_complex_types: - return _all_complex_types[tpe].fromJson(json_value) + return _all_complex_types[tpe].fromJson(json_value) # type: ignore[attr-defined] elif tpe == 'udt': return UserDefinedType.fromJson(json_value) else: @@ -987,21 +1038,25 @@ _array_unsigned_int_typecode_ctype_mappings = { } -def _int_size_to_type(size): +def _int_size_to_type( + size: int, +) -> Optional[Union[Type[ByteType], Type[ShortType], Type[IntegerType], Type[LongType]]]: """ Return the Catalyst datatype from the size of integers. """ if size <= 8: return ByteType - if size <= 16: + elif size <= 16: return ShortType - if size <= 32: + elif size <= 32: return IntegerType - if size <= 64: + elif size <= 64: return LongType + else: + return None # The list of all supported array typecodes, is stored here -_array_type_mappings = { +_array_type_mappings: Dict[str, Type[DataType]] = { # Warning: Actual properties for float and double in C is not specified in C. 


 # The list of all supported array typecodes, is stored here
-_array_type_mappings = {
+_array_type_mappings: Dict[str, Type[DataType]] = {
     # Warning: Actual properties for float and double in C is not specified in C.
     # On almost every system supported by both python and JVM, they are IEEE 754
     # single-precision binary floating-point format and IEEE 754 double-precision
@@ -1032,7 +1087,11 @@ if sys.version_info[0] < 4:
     _array_type_mappings['u'] = StringType


-def _infer_type(obj, infer_dict_as_struct=False, prefer_timestamp_ntz=False):
+def _infer_type(
+    obj: Any,
+    infer_dict_as_struct: bool = False,
+    prefer_timestamp_ntz: bool = False,
+) -> DataType:
     """Infer the DataType from obj
     """
     if obj is None:
@@ -1083,21 +1142,29 @@ def _infer_type(obj, infer_dict_as_struct=False, prefer_timestamp_ntz=False):
     raise TypeError("not supported type: %s" % type(obj))


-def _infer_schema(row, names=None, infer_dict_as_struct=False, prefer_timestamp_ntz=False):
+def _infer_schema(
+    row: Any,
+    names: Optional[List[str]] = None,
+    infer_dict_as_struct: bool = False,
+    prefer_timestamp_ntz: bool = False,
+) -> StructType:
     """Infer the schema from dict/namedtuple/object"""
+    items: Iterable[Tuple[str, Any]]
     if isinstance(row, dict):
         items = sorted(row.items())

     elif isinstance(row, (tuple, list)):
         if hasattr(row, "__fields__"):  # Row
-            items = zip(row.__fields__, tuple(row))
+            items = zip(row.__fields__, tuple(row))  # type: ignore[union-attr]
         elif hasattr(row, "_fields"):  # namedtuple
-            items = zip(row._fields, tuple(row))
+            items = zip(row._fields, tuple(row))  # type: ignore[union-attr]
         else:
             if names is None:
-                names = ['_%d' % i for i in range(1, len(row) + 1)]
+                names = [
+                    '_%d' % i for i in range(1, len(row) + 1)]
             elif len(names) < len(row):
-                names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1))
+                names.extend(
+                    '_%d' % i for i in range(len(names) + 1, len(row) + 1))
             items = zip(names, row)

     elif hasattr(row, "__dict__"):  # object
@@ -1116,7 +1183,7 @@ def _infer_schema(row, names=None, infer_dict_as_struct=False, prefer_timestamp_
     return StructType(fields)


-def _has_nulltype(dt):
+def _has_nulltype(dt: DataType) -> bool:
     """ Return whether there is a NullType in `dt` or not """
     if isinstance(dt, StructType):
         return any(_has_nulltype(f.dataType) for f in dt.fields)
@@ -1128,7 +1195,31 @@ def _has_nulltype(dt):
     return isinstance(dt, NullType)


-def _merge_type(a, b, name=None):
+@overload
+def _merge_type(a: StructType, b: StructType, name: Optional[str] = None) -> StructType:
+    ...
+
+
+@overload
+def _merge_type(a: ArrayType, b: ArrayType, name: Optional[str] = None) -> ArrayType:
+    ...
+
+
+@overload
+def _merge_type(a: MapType, b: MapType, name: Optional[str] = None) -> MapType:
+    ...
+
+
+@overload
+def _merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType:
+    ...
+
+
+def _merge_type(
+    a: Union[StructType, ArrayType, MapType, DataType],
+    b: Union[StructType, ArrayType, MapType, DataType],
+    name: Optional[str] = None,
+) -> Union[StructType, ArrayType, MapType, DataType]:
     if name is None:
         new_msg = lambda msg: msg
         new_name = lambda n: "field %s" % n
@@ -1150,7 +1241,7 @@ def _merge_type(a, b, name=None):

     # same type
     if isinstance(a, StructType):
-        nfs = dict((f.name, f.dataType) for f in b.fields)
+        nfs = dict((f.name, f.dataType) for f in cast(StructType, b).fields)
         fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()),
                                                   name=new_name(f.name)))
                   for f in a.fields]
@@ -1161,18 +1252,19 @@ def _merge_type(a, b, name=None):
         return StructType(fields)

     elif isinstance(a, ArrayType):
-        return ArrayType(_merge_type(a.elementType, b.elementType,
+        return ArrayType(_merge_type(a.elementType, cast(ArrayType, b).elementType,
                                      name='element in array %s' % name), True)

     elif isinstance(a, MapType):
-        return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name),
-                       _merge_type(a.valueType, b.valueType, name='value of map %s' % name),
-                       True)
+        return MapType(
+            _merge_type(a.keyType, cast(MapType, b).keyType, name='key of map %s' % name),
+            _merge_type(a.valueType, cast(MapType, b).valueType, name='value of map %s' % name),
+            True)
     else:
         return a


-def _need_converter(dataType):
+def _need_converter(dataType: DataType) -> bool:
     if isinstance(dataType, StructType):
         return True
     elif isinstance(dataType, ArrayType):
@@ -1185,7 +1277,7 @@ def _need_converter(dataType):
     return False


-def _create_converter(dataType):
+def _create_converter(dataType: DataType) -> Callable:
     """Create a converter to drop the names of fields in obj """
     if not _need_converter(dataType):
         return lambda x: x
@@ -1210,9 +1302,9 @@ def _create_converter(dataType):
     converters = [_create_converter(f.dataType) for f in dataType.fields]
     convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)

-    def convert_struct(obj):
+    def convert_struct(obj: Any) -> Optional[Tuple]:
         if obj is None:
-            return
+            return None

         if isinstance(obj, (tuple, list)):
             if convert_fields:
@@ -1255,7 +1347,11 @@ _acceptable_types = {
 }


-def _make_type_verifier(dataType, nullable=True, name=None):
+def _make_type_verifier(
+    dataType: DataType,
+    nullable: bool = True,
+    name: Optional[str] = None,
+) -> Callable:
     """
     Make a verifier that checks the type of obj against dataType and raises a TypeError if they
     do not match.
@@ -1318,7 +1414,7 @@ def _make_type_verifier(dataType, nullable=True, name=None):
         new_msg = lambda msg: "%s: %s" % (name, msg)
         new_name = lambda n: "field %s in %s" % (n, name)

-    def verify_nullability(obj):
+    def verify_nullability(obj: Any) -> bool:
         if obj is None:
             if nullable:
                 return True
@@ -1329,13 +1425,13 @@ def _make_type_verifier(dataType, nullable=True, name=None):

     _type = type(dataType)

-    def assert_acceptable_types(obj):
+    def assert_acceptable_types(obj: Any) -> None:
         assert _type in _acceptable_types, \
             new_msg("unknown datatype: %s for object %r" % (dataType, obj))

-    def verify_acceptable_types(obj):
+    def verify_acceptable_types(obj: Any) -> None:
         # subclass of them can not be fromInternal in JVM
-        if type(obj) not in _acceptable_types[_type]:
+        if type(obj) not in _acceptable_types[_type]:  # type: ignore[operator]
             raise TypeError(new_msg("%s can not accept object %r in type %s"
                                     % (dataType, obj, type(obj))))

@@ -1346,7 +1442,7 @@ def _make_type_verifier(dataType, nullable=True, name=None):
     elif isinstance(dataType, UserDefinedType):
         verifier = _make_type_verifier(dataType.sqlType(), name=name)

-        def verify_udf(obj):
+        def verify_udf(obj: Any) -> None:
             if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
                 raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
             verifier(dataType.toInternal(obj))

         verify_value = verify_udf

     elif isinstance(dataType, ByteType):
-        def verify_byte(obj):
+        def verify_byte(obj: Any) -> None:
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -128 or obj > 127:
@@ -1363,7 +1459,7 @@ def _make_type_verifier(dataType, nullable=True, name=None):
         verify_value = verify_byte

     elif isinstance(dataType, ShortType):
-        def verify_short(obj):
+        def verify_short(obj: Any) -> None:
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -32768 or obj > 32767:
@@ -1372,7 +1468,7 @@ def _make_type_verifier(dataType, nullable=True, name=None):
         verify_value = verify_short

     elif isinstance(dataType, IntegerType):
-        def verify_integer(obj):
+        def verify_integer(obj: Any) -> None:
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -2147483648 or obj > 2147483647:
@@ -1382,7 +1478,7 @@ def _make_type_verifier(dataType, nullable=True, name=None):
         verify_value = verify_integer

     elif isinstance(dataType, LongType):
-        def verify_long(obj):
+        def verify_long(obj: Any) -> None:
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -9223372036854775808 or obj > 9223372036854775807:
@@ -1395,7 +1491,7 @@ def _make_type_verifier(dataType, nullable=True, name=None):
         element_verifier = _make_type_verifier(
             dataType.elementType, dataType.containsNull, name="element in array %s" % name)

-        def verify_array(obj):
+        def verify_array(obj: Any) -> None:
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             for i in obj:
@@ -1408,7 +1504,7 @@ def _make_type_verifier(dataType, nullable=True, name=None):
         value_verifier = _make_type_verifier(
             dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name)

-        def verify_map(obj):
+        def verify_map(obj: Any) -> None:
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             for k, v in obj.items():
@@ -1420,10 +1516,11 @@ def _make_type_verifier(dataType, nullable=True, name=None):
     elif isinstance(dataType, StructType):
         verifiers = []
         for f in dataType.fields:
-            verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name))
+            verifier = _make_type_verifier(
_make_type_verifier( + f.dataType, f.nullable, name=new_name(f.name)) # type: ignore[arg-type] verifiers.append((f.name, verifier)) - def verify_struct(obj): + def verify_struct(obj: Any) -> None: assert_acceptable_types(obj) if isinstance(obj, dict): @@ -1446,13 +1543,13 @@ def _make_type_verifier(dataType, nullable=True, name=None): verify_value = verify_struct else: - def verify_default(obj): + def verify_default(obj: Any) -> None: assert_acceptable_types(obj) verify_acceptable_types(obj) verify_value = verify_default - def verify(obj): + def verify(obj: Any) -> None: if not verify_nullability(obj): verify_value(obj) @@ -1460,11 +1557,13 @@ def _make_type_verifier(dataType, nullable=True, name=None): # This is used to unpickle a Row from JVM -def _create_row_inbound_converter(dataType): +def _create_row_inbound_converter(dataType: DataType) -> Callable: return lambda *a: dataType.fromInternal(a) -def _create_row(fields, values): +def _create_row( + fields: Union["Row", List[str]], values: Union[Tuple[Any, ...], List[Any]] +) -> "Row": row = Row(*values) row.__fields__ = fields return row @@ -1526,7 +1625,15 @@ class Row(tuple): True """ - def __new__(cls, *args, **kwargs): + @overload + def __new__(cls, *args: str) -> "Row": + ... + + @overload + def __new__(cls, **kwargs: Any) -> "Row": + ... + + def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": if args and kwargs: raise ValueError("Can not use both args " "and kwargs to create Row") @@ -1539,7 +1646,7 @@ class Row(tuple): # create row class or objects return tuple.__new__(cls, args) - def asDict(self, recursive=False): + def asDict(self, recursive: bool = False) -> Dict[str, Any]: """ Return as a dict @@ -1570,7 +1677,7 @@ class Row(tuple): raise TypeError("Cannot convert a Row class into dict") if recursive: - def conv(obj): + def conv(obj: Any) -> Any: if isinstance(obj, Row): return obj.asDict(True) elif isinstance(obj, list): @@ -1583,21 +1690,21 @@ class Row(tuple): else: return dict(zip(self.__fields__, self)) - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: if hasattr(self, "__fields__"): return item in self.__fields__ else: return super(Row, self).__contains__(item) # let object acts like class - def __call__(self, *args): + def __call__(self, *args: Any) -> "Row": """create new Row object""" if len(args) > len(self): raise ValueError("Can not create Row with fields %s, expected %d values " "but got %s" % (self, len(self), args)) return _create_row(self, args) - def __getitem__(self, item): + def __getitem__(self, item: Any) -> Any: if isinstance(item, (int, slice)): return super(Row, self).__getitem__(item) try: @@ -1610,7 +1717,7 @@ class Row(tuple): except ValueError: raise ValueError(item) - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: if item.startswith("__"): raise AttributeError(item) try: @@ -1623,19 +1730,21 @@ class Row(tuple): except ValueError: raise AttributeError(item) - def __setattr__(self, key, value): + def __setattr__(self, key: Any, value: Any) -> None: if key != '__fields__': raise RuntimeError("Row is read-only") self.__dict__[key] = value - def __reduce__(self): + def __reduce__( + self, + ) -> Union[str, Tuple[Any, ...]]: """Returns a tuple so Python knows how to pickle Row.""" if hasattr(self, "__fields__"): return (_create_row, (self.__fields__, tuple(self))) else: return tuple.__reduce__(self) - def __repr__(self): + def __repr__(self) -> str: """Printable representation of Row used in Python REPL.""" if hasattr(self, 
"__fields__"): return "Row(%s)" % ", ".join("%s=%r" % (k, v) @@ -1645,19 +1754,19 @@ class Row(tuple): class DateConverter(object): - def can_convert(self, obj): + def can_convert(self, obj: Any) -> bool: return isinstance(obj, datetime.date) - def convert(self, obj, gateway_client): + def convert(self, obj: datetime.date, gateway_client: JavaGateway) -> JavaObject: Date = JavaClass("java.sql.Date", gateway_client) return Date.valueOf(obj.strftime("%Y-%m-%d")) class DatetimeConverter(object): - def can_convert(self, obj): + def can_convert(self, obj: Any) -> bool: return isinstance(obj, datetime.datetime) - def convert(self, obj, gateway_client): + def convert(self, obj: datetime.datetime, gateway_client: JavaGateway) -> JavaObject: Timestamp = JavaClass("java.sql.Timestamp", gateway_client) seconds = (calendar.timegm(obj.utctimetuple()) if obj.tzinfo else time.mktime(obj.timetuple())) @@ -1667,7 +1776,7 @@ class DatetimeConverter(object): class DatetimeNTZConverter(object): - def can_convert(self, obj): + def can_convert(self, obj: Any) -> bool: from pyspark.sql.utils import is_timestamp_ntz_preferred return ( @@ -1675,11 +1784,11 @@ class DatetimeNTZConverter(object): obj.tzinfo is None and is_timestamp_ntz_preferred()) - def convert(self, obj, gateway_client): + def convert(self, obj: datetime.datetime, gateway_client: JavaGateway) -> JavaObject: from pyspark import SparkContext seconds = calendar.timegm(obj.utctimetuple()) - jvm = SparkContext._jvm + jvm = SparkContext._jvm # type: ignore[attr-defined] return jvm.org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToLocalDateTime( int(seconds) * 1000000 + obj.microsecond ) @@ -1691,7 +1800,7 @@ register_input_converter(DatetimeConverter()) register_input_converter(DateConverter()) -def _test(): +def _test() -> None: import doctest from pyspark.context import SparkContext from pyspark.sql import SparkSession diff --git a/python/pyspark/sql/types.pyi b/python/pyspark/sql/types.pyi deleted file mode 100644 index 58c646f..0000000 --- a/python/pyspark/sql/types.pyi +++ /dev/null @@ -1,210 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import overload -from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TypeVar -from py4j.java_gateway import JavaGateway, JavaObject -import datetime - -T = TypeVar("T") -U = TypeVar("U") - -class DataType: - def __hash__(self) -> int: ... - def __eq__(self, other: Any) -> bool: ... - def __ne__(self, other: Any) -> bool: ... - @classmethod - def typeName(cls) -> str: ... - def simpleString(self) -> str: ... - def jsonValue(self) -> Union[str, Dict[str, Any]]: ... - def json(self) -> str: ... - def needConversion(self) -> bool: ... - def toInternal(self, obj: Any) -> Any: ... 
-    def fromInternal(self, obj: Any) -> Any: ...
-
-class DataTypeSingleton(type):
-    def __call__(cls: Type[T]) -> T: ...  # type: ignore
-
-class NullType(DataType, metaclass=DataTypeSingleton): ...
-class AtomicType(DataType): ...
-class NumericType(AtomicType): ...
-class IntegralType(NumericType, metaclass=DataTypeSingleton): ...
-class FractionalType(NumericType): ...
-class StringType(AtomicType, metaclass=DataTypeSingleton): ...
-class BinaryType(AtomicType, metaclass=DataTypeSingleton): ...
-class BooleanType(AtomicType, metaclass=DataTypeSingleton): ...
-
-class DateType(AtomicType, metaclass=DataTypeSingleton):
-    EPOCH_ORDINAL: int
-    def needConversion(self) -> bool: ...
-    def toInternal(self, d: datetime.date) -> int: ...
-    def fromInternal(self, v: int) -> datetime.date: ...
-
-class TimestampType(AtomicType, metaclass=DataTypeSingleton):
-    def needConversion(self) -> bool: ...
-    def toInternal(self, dt: datetime.datetime) -> int: ...
-    def fromInternal(self, ts: int) -> datetime.datetime: ...
-
-class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton):
-    def needConversion(self) -> bool: ...
-    def toInternal(self, dt: datetime.datetime) -> int: ...
-    def fromInternal(self, ts: int) -> datetime.datetime: ...
-
-class DecimalType(FractionalType):
-    precision: int
-    scale: int
-    hasPrecisionInfo: bool
-    def __init__(self, precision: int = ..., scale: int = ...) -> None: ...
-    def simpleString(self) -> str: ...
-    def jsonValue(self) -> str: ...
-
-class DoubleType(FractionalType, metaclass=DataTypeSingleton): ...
-class FloatType(FractionalType, metaclass=DataTypeSingleton): ...
-
-class ByteType(IntegralType):
-    def simpleString(self) -> str: ...
-
-class IntegerType(IntegralType):
-    def simpleString(self) -> str: ...
-
-class LongType(IntegralType):
-    def simpleString(self) -> str: ...
-
-class ShortType(IntegralType):
-    def simpleString(self) -> str: ...
-
-class ArrayType(DataType):
-    elementType: DataType
-    containsNull: bool
-    def __init__(self, elementType: DataType, containsNull: bool = ...) -> None: ...
-    def simpleString(self) -> str: ...
-    def jsonValue(self) -> Dict[str, Any]: ...
-    @classmethod
-    def fromJson(cls, json: Dict[str, Any]) -> ArrayType: ...
-    def needConversion(self) -> bool: ...
-    def toInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: ...
-    def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: ...
-
-class MapType(DataType):
-    keyType: DataType
-    valueType: DataType
-    valueContainsNull: bool
-    def __init__(
-        self, keyType: DataType, valueType: DataType, valueContainsNull: bool = ...
-    ) -> None: ...
-    def simpleString(self) -> str: ...
-    def jsonValue(self) -> Dict[str, Any]: ...
-    @classmethod
-    def fromJson(cls, json: Dict[str, Any]) -> MapType: ...
-    def needConversion(self) -> bool: ...
-    def toInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: ...
-    def fromInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: ...
-
-class StructField(DataType):
-    name: str
-    dataType: DataType
-    nullable: bool
-    metadata: Dict[str, Any]
-    def __init__(
-        self,
-        name: str,
-        dataType: DataType,
-        nullable: bool = ...,
-        metadata: Optional[Dict[str, Any]] = ...,
-    ) -> None: ...
-    def simpleString(self) -> str: ...
-    def jsonValue(self) -> Dict[str, Any]: ...
-    @classmethod
-    def fromJson(cls, json: Dict[str, Any]) -> StructField: ...
-    def needConversion(self) -> bool: ...
-    def toInternal(self, obj: T) -> T: ...
-    def fromInternal(self, obj: T) -> T: ...
-
-class StructType(DataType):
-    fields: List[StructField]
-    names: List[str]
-    def __init__(self, fields: Optional[List[StructField]] = ...) -> None: ...
-    @overload
-    def add(
-        self,
-        field: str,
-        data_type: Union[str, DataType],
-        nullable: bool = ...,
-        metadata: Optional[Dict[str, Any]] = ...,
-    ) -> StructType: ...
-    @overload
-    def add(self, field: StructField) -> StructType: ...
-    def __iter__(self) -> Iterator[StructField]: ...
-    def __len__(self) -> int: ...
-    def __getitem__(self, key: Union[str, int]) -> StructField: ...
-    def simpleString(self) -> str: ...
-    def jsonValue(self) -> Dict[str, Any]: ...
-    @classmethod
-    def fromJson(cls, json: Dict[str, Any]) -> StructType: ...
-    def fieldNames(self) -> List[str]: ...
-    def needConversion(self) -> bool: ...
-    def toInternal(self, obj: Tuple) -> Tuple: ...
-    def fromInternal(self, obj: Tuple) -> Row: ...
-
-class UserDefinedType(DataType):
-    @classmethod
-    def typeName(cls) -> str: ...
-    @classmethod
-    def sqlType(cls) -> DataType: ...
-    @classmethod
-    def module(cls) -> str: ...
-    @classmethod
-    def scalaUDT(cls) -> str: ...
-    def needConversion(self) -> bool: ...
-    def toInternal(self, obj: Any) -> Any: ...
-    def fromInternal(self, obj: Any) -> Any: ...
-    def serialize(self, obj: Any) -> Any: ...
-    def deserialize(self, datum: Any) -> Any: ...
-    def simpleString(self) -> str: ...
-    def json(self) -> str: ...
-    def jsonValue(self) -> Dict[str, Any]: ...
-    @classmethod
-    def fromJson(cls, json: Dict[str, Any]) -> UserDefinedType: ...
-    def __eq__(self, other: Any) -> bool: ...
-
-class Row(tuple):
-    @overload
-    def __new__(self, *args: str) -> Row: ...
-    @overload
-    def __new__(self, **kwargs: Any) -> Row: ...
-    @overload
-    def __init__(self, *args: str) -> None: ...
-    @overload
-    def __init__(self, **kwargs: Any) -> None: ...
-    def asDict(self, recursive: bool = ...) -> Dict[str, Any]: ...
-    def __contains__(self, item: Any) -> bool: ...
-    def __call__(self, *args: Any) -> Row: ...
-    def __getitem__(self, item: Any) -> Any: ...
-    def __getattr__(self, item: str) -> Any: ...
-    def __setattr__(self, key: Any, value: Any) -> None: ...
-    def __reduce__(
-        self,
-    ) -> Tuple[Callable[[List[str], List[Any]], Row], Tuple[List[str], Tuple]]: ...
-
-class DateConverter:
-    def can_convert(self, obj: Any) -> bool: ...
-    def convert(self, obj: datetime.date, gateway_client: JavaGateway) -> JavaObject: ...
-
-class DatetimeConverter:
-    def can_convert(self, obj: Any) -> bool: ...
-    def convert(self, obj: datetime.datetime, gateway_client: JavaGateway) -> JavaObject: ...

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org