This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 910c3733bfdd Revert "[SPARK-48415][PYTHON] Refactor TypeName to support parameterized datatypes" 910c3733bfdd is described below commit 910c3733bfdd1a0f386137d48796e317f64f7f50 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu May 30 16:21:22 2024 +0800 Revert "[SPARK-48415][PYTHON] Refactor TypeName to support parameterized datatypes" revert https://github.com/apache/spark/pull/46738 Closes #46804 from zhengruifeng/revert_typename_oss. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/tests/test_types.py | 133 --------------------------------- python/pyspark/sql/types.py | 74 +++++++++++------- 2 files changed, 47 insertions(+), 160 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index cc482b886e3a..80f2c0fcbc03 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -81,139 +81,6 @@ from pyspark.testing.utils import PySparkErrorTestUtils class TypesTestsMixin: - def test_class_method_type_name(self): - for dataType, expected in [ - (StringType, "string"), - (CharType, "char"), - (VarcharType, "varchar"), - (BinaryType, "binary"), - (BooleanType, "boolean"), - (DecimalType, "decimal"), - (FloatType, "float"), - (DoubleType, "double"), - (ByteType, "byte"), - (ShortType, "short"), - (IntegerType, "integer"), - (LongType, "long"), - (DateType, "date"), - (TimestampType, "timestamp"), - (TimestampNTZType, "timestamp_ntz"), - (NullType, "void"), - (VariantType, "variant"), - (YearMonthIntervalType, "yearmonthinterval"), - (DayTimeIntervalType, "daytimeinterval"), - (CalendarIntervalType, "interval"), - ]: - self.assertEqual(dataType.typeName(), expected) - - def test_instance_method_type_name(self): - for dataType, expected in [ - (StringType(), "string"), - (CharType(5), "char(5)"), - (VarcharType(10), 
"varchar(10)"), - (BinaryType(), "binary"), - (BooleanType(), "boolean"), - (DecimalType(), "decimal(10,0)"), - (DecimalType(10, 2), "decimal(10,2)"), - (FloatType(), "float"), - (DoubleType(), "double"), - (ByteType(), "byte"), - (ShortType(), "short"), - (IntegerType(), "integer"), - (LongType(), "long"), - (DateType(), "date"), - (TimestampType(), "timestamp"), - (TimestampNTZType(), "timestamp_ntz"), - (NullType(), "void"), - (VariantType(), "variant"), - (YearMonthIntervalType(), "interval year to month"), - (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"), - ( - YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH), - "interval year to month", - ), - (DayTimeIntervalType(), "interval day to second"), - (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"), - ( - DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND), - "interval hour to second", - ), - (CalendarIntervalType(), "interval"), - ]: - self.assertEqual(dataType.typeName(), expected) - - def test_simple_string(self): - for dataType, expected in [ - (StringType(), "string"), - (CharType(5), "char(5)"), - (VarcharType(10), "varchar(10)"), - (BinaryType(), "binary"), - (BooleanType(), "boolean"), - (DecimalType(), "decimal(10,0)"), - (DecimalType(10, 2), "decimal(10,2)"), - (FloatType(), "float"), - (DoubleType(), "double"), - (ByteType(), "tinyint"), - (ShortType(), "smallint"), - (IntegerType(), "int"), - (LongType(), "bigint"), - (DateType(), "date"), - (TimestampType(), "timestamp"), - (TimestampNTZType(), "timestamp_ntz"), - (NullType(), "void"), - (VariantType(), "variant"), - (YearMonthIntervalType(), "interval year to month"), - (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"), - ( - YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH), - "interval year to month", - ), - (DayTimeIntervalType(), "interval day to second"), - (DayTimeIntervalType(DayTimeIntervalType.DAY), 
"interval day"), - ( - DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND), - "interval hour to second", - ), - (CalendarIntervalType(), "interval"), - ]: - self.assertEqual(dataType.simpleString(), expected) - - def test_json_value(self): - for dataType, expected in [ - (StringType(), "string"), - (CharType(5), "char(5)"), - (VarcharType(10), "varchar(10)"), - (BinaryType(), "binary"), - (BooleanType(), "boolean"), - (DecimalType(), "decimal(10,0)"), - (DecimalType(10, 2), "decimal(10,2)"), - (FloatType(), "float"), - (DoubleType(), "double"), - (ByteType(), "byte"), - (ShortType(), "short"), - (IntegerType(), "integer"), - (LongType(), "long"), - (DateType(), "date"), - (TimestampType(), "timestamp"), - (TimestampNTZType(), "timestamp_ntz"), - (NullType(), "void"), - (VariantType(), "variant"), - (YearMonthIntervalType(), "interval year to month"), - (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"), - ( - YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH), - "interval year to month", - ), - (DayTimeIntervalType(), "interval day to second"), - (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"), - ( - DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND), - "interval hour to second", - ), - (CalendarIntervalType(), "interval"), - ]: - self.assertEqual(dataType.jsonValue(), expected) - def test_apply_schema_to_row(self): df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""])) df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index c72ff72ce426..c0f60f839356 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -115,11 +115,7 @@ class DataType: return hash(str(self)) def __eq__(self, other: Any) -> bool: - if isinstance(other, self.__class__): - self_dict = {k: v for k, v in self.__dict__.items() if k != "typeName"} - other_dict = {k: v for k, v 
in other.__dict__.items() if k != "typeName"} - return self_dict == other_dict - return False + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @@ -128,12 +124,6 @@ class DataType: def typeName(cls) -> str: return cls.__name__[:-4].lower() - # The classmethod 'typeName' is not always consistent with the Scala side, e.g. - # DecimalType(10, 2): 'decimal' vs 'decimal(10, 2)' - # This method is used in subclass initializer to replace 'typeName' if they are different. - def _type_name(self) -> str: - return self.__class__.__name__.removesuffix("Type").removesuffix("UDT").lower() - def simpleString(self) -> str: return self.typeName() @@ -225,6 +215,24 @@ class DataType: if isinstance(dataType, (ArrayType, StructType, MapType)): dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1) + # The method typeName() is not always the same as the Scala side. + # Add this helper method to make TreeString() compatible with Scala side. 
+ @classmethod + def _get_jvm_type_name(cls, dataType: "DataType") -> str: + if isinstance( + dataType, + ( + DecimalType, + CharType, + VarcharType, + DayTimeIntervalType, + YearMonthIntervalType, + ), + ): + return dataType.simpleString() + else: + return dataType.typeName() + # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle @@ -285,7 +293,6 @@ class StringType(AtomicType): providers = [providerSpark, providerICU] def __init__(self, collation: str = "UTF8_BINARY"): - self.typeName = self._type_name # type: ignore[method-assign] self.collation = collation @classmethod @@ -295,7 +302,7 @@ class StringType(AtomicType): return StringType.providerSpark return StringType.providerICU - def _type_name(self) -> str: + def simpleString(self) -> str: if self.isUTF8BinaryCollation(): return "string" @@ -326,10 +333,12 @@ class CharType(AtomicType): """ def __init__(self, length: int): - self.typeName = self._type_name # type: ignore[method-assign] self.length = length - def _type_name(self) -> str: + def simpleString(self) -> str: + return "char(%d)" % (self.length) + + def jsonValue(self) -> str: return "char(%d)" % (self.length) def __repr__(self) -> str: @@ -346,10 +355,12 @@ class VarcharType(AtomicType): """ def __init__(self, length: int): - self.typeName = self._type_name # type: ignore[method-assign] self.length = length - def _type_name(self) -> str: + def simpleString(self) -> str: + return "varchar(%d)" % (self.length) + + def jsonValue(self) -> str: return "varchar(%d)" % (self.length) def __repr__(self) -> str: @@ -448,12 +459,14 @@ class DecimalType(FractionalType): """ def __init__(self, precision: int = 10, scale: int = 0): - self.typeName = self._type_name # type: ignore[method-assign] self.precision = precision self.scale = scale self.hasPrecisionInfo = True # this is a public API - def _type_name(self) -> str: + def simpleString(self) -> str: + return "decimal(%d,%d)" % (self.precision, self.scale) 
+ + def jsonValue(self) -> str: return "decimal(%d,%d)" % (self.precision, self.scale) def __repr__(self) -> str: @@ -528,7 +541,6 @@ class DayTimeIntervalType(AnsiIntervalType): _inverted_fields = dict(zip(_fields.values(), _fields.keys())) def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None): - self.typeName = self._type_name # type: ignore[method-assign] if startField is None and endField is None: # Default matched to scala side. startField = DayTimeIntervalType.DAY @@ -545,7 +557,7 @@ class DayTimeIntervalType(AnsiIntervalType): self.startField = startField self.endField = endField - def _type_name(self) -> str: + def _str_repr(self) -> str: fields = DayTimeIntervalType._fields start_field_name = fields[self.startField] end_field_name = fields[self.endField] @@ -554,6 +566,10 @@ class DayTimeIntervalType(AnsiIntervalType): else: return "interval %s to %s" % (start_field_name, end_field_name) + simpleString = _str_repr + + jsonValue = _str_repr + def __repr__(self) -> str: return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField) @@ -583,7 +599,6 @@ class YearMonthIntervalType(AnsiIntervalType): _inverted_fields = dict(zip(_fields.values(), _fields.keys())) def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None): - self.typeName = self._type_name # type: ignore[method-assign] if startField is None and endField is None: # Default matched to scala side. 
startField = YearMonthIntervalType.YEAR @@ -600,7 +615,7 @@ class YearMonthIntervalType(AnsiIntervalType): self.startField = startField self.endField = endField - def _type_name(self) -> str: + def _str_repr(self) -> str: fields = YearMonthIntervalType._fields start_field_name = fields[self.startField] end_field_name = fields[self.endField] @@ -609,6 +624,10 @@ class YearMonthIntervalType(AnsiIntervalType): else: return "interval %s to %s" % (start_field_name, end_field_name) + simpleString = _str_repr + + jsonValue = _str_repr + def __repr__(self) -> str: return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField) @@ -742,7 +761,7 @@ class ArrayType(DataType): ) -> None: if maxDepth > 0: stringConcat.append( - f"{prefix}-- element: {self.elementType.typeName()} " + f"{prefix}-- element: {DataType._get_jvm_type_name(self.elementType)} " + f"(containsNull = {str(self.containsNull).lower()})\n" ) DataType._data_type_build_formatted_string( @@ -890,12 +909,12 @@ class MapType(DataType): maxDepth: int = JVM_INT_MAX, ) -> None: if maxDepth > 0: - stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n") + stringConcat.append(f"{prefix}-- key: {DataType._get_jvm_type_name(self.keyType)}\n") DataType._data_type_build_formatted_string( self.keyType, f"{prefix} |", stringConcat, maxDepth ) stringConcat.append( - f"{prefix}-- value: {self.valueType.typeName()} " + f"{prefix}-- value: {DataType._get_jvm_type_name(self.valueType)} " + f"(valueContainsNull = {str(self.valueContainsNull).lower()})\n" ) DataType._data_type_build_formatted_string( @@ -1058,7 +1077,8 @@ class StructField(DataType): ) -> None: if maxDepth > 0: stringConcat.append( - f"{prefix}-- {escape_meta_characters(self.name)}: {self.dataType.typeName()} " + f"{prefix}-- {escape_meta_characters(self.name)}: " + + f"{DataType._get_jvm_type_name(self.dataType)} " + f"(nullable = {str(self.nullable).lower()})\n" ) DataType._data_type_build_formatted_string( 
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org