This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new fc1435d14d09  [SPARK-48415][PYTHON] Refactor `TypeName` to support parameterized datatypes
fc1435d14d09 is described below

commit fc1435d14d090b792a0f19372d6b11c7ff026372
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue May 28 08:39:28 2024 +0800

    [SPARK-48415][PYTHON] Refactor `TypeName` to support parameterized datatypes

    ### What changes were proposed in this pull request?
    1. Refactor the instance method `typeName` to support parameterized datatypes.
    2. Remove the redundant `simpleString`/`jsonValue` methods, since they return the type name by default.

    ### Why are the changes needed?
    To be consistent with the Scala side.

    ### Does this PR introduce _any_ user-facing change?
    Yes, the following type names change:

    `CharType(10)`: `char` -> `char(10)`
    `VarcharType(10)`: `varchar` -> `varchar(10)`
    `DecimalType(10, 2)`: `decimal` -> `decimal(10,2)`
    `DayTimeIntervalType(DAY, HOUR)`: `daytimeinterval` -> `interval day to hour`
    `YearMonthIntervalType(YEAR, MONTH)`: `yearmonthinterval` -> `interval year to month`

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #46738 from zhengruifeng/py_type_name.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_types.py | 133 +++++++++++++++++++++++++++++++++
 python/pyspark/sql/types.py            |  74 +++++++-----------
 2 files changed, 160 insertions(+), 47 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 80f2c0fcbc03..cc482b886e3a 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -81,6 +81,139 @@ from pyspark.testing.utils import PySparkErrorTestUtils


 class TypesTestsMixin:
+    def test_class_method_type_name(self):
+        for dataType, expected in [
+            (StringType, "string"),
+            (CharType, "char"),
+            (VarcharType, "varchar"),
+            (BinaryType, "binary"),
+            (BooleanType, "boolean"),
+            (DecimalType, "decimal"),
+            (FloatType, "float"),
+            (DoubleType, "double"),
+            (ByteType, "byte"),
+            (ShortType, "short"),
+            (IntegerType, "integer"),
+            (LongType, "long"),
+            (DateType, "date"),
+            (TimestampType, "timestamp"),
+            (TimestampNTZType, "timestamp_ntz"),
+            (NullType, "void"),
+            (VariantType, "variant"),
+            (YearMonthIntervalType, "yearmonthinterval"),
+            (DayTimeIntervalType, "daytimeinterval"),
+            (CalendarIntervalType, "interval"),
+        ]:
+            self.assertEqual(dataType.typeName(), expected)
+
+    def test_instance_method_type_name(self):
+        for dataType, expected in [
+            (StringType(), "string"),
+            (CharType(5), "char(5)"),
+            (VarcharType(10), "varchar(10)"),
+            (BinaryType(), "binary"),
+            (BooleanType(), "boolean"),
+            (DecimalType(), "decimal(10,0)"),
+            (DecimalType(10, 2), "decimal(10,2)"),
+            (FloatType(), "float"),
+            (DoubleType(), "double"),
+            (ByteType(), "byte"),
+            (ShortType(), "short"),
+            (IntegerType(), "integer"),
+            (LongType(), "long"),
+            (DateType(), "date"),
+            (TimestampType(), "timestamp"),
+            (TimestampNTZType(), "timestamp_ntz"),
+            (NullType(), "void"),
+            (VariantType(), "variant"),
+            (YearMonthIntervalType(), "interval year to month"),
+            (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"),
+            (
+                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
+                "interval year to month",
+            ),
+            (DayTimeIntervalType(), "interval day to second"),
+            (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
+            (
+                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
+                "interval hour to second",
+            ),
+            (CalendarIntervalType(), "interval"),
+        ]:
+            self.assertEqual(dataType.typeName(), expected)
+
+    def test_simple_string(self):
+        for dataType, expected in [
+            (StringType(), "string"),
+            (CharType(5), "char(5)"),
+            (VarcharType(10), "varchar(10)"),
+            (BinaryType(), "binary"),
+            (BooleanType(), "boolean"),
+            (DecimalType(), "decimal(10,0)"),
+            (DecimalType(10, 2), "decimal(10,2)"),
+            (FloatType(), "float"),
+            (DoubleType(), "double"),
+            (ByteType(), "tinyint"),
+            (ShortType(), "smallint"),
+            (IntegerType(), "int"),
+            (LongType(), "bigint"),
+            (DateType(), "date"),
+            (TimestampType(), "timestamp"),
+            (TimestampNTZType(), "timestamp_ntz"),
+            (NullType(), "void"),
+            (VariantType(), "variant"),
+            (YearMonthIntervalType(), "interval year to month"),
+            (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"),
+            (
+                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
+                "interval year to month",
+            ),
+            (DayTimeIntervalType(), "interval day to second"),
+            (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
+            (
+                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
+                "interval hour to second",
+            ),
+            (CalendarIntervalType(), "interval"),
+        ]:
+            self.assertEqual(dataType.simpleString(), expected)
+
+    def test_json_value(self):
+        for dataType, expected in [
+            (StringType(), "string"),
+            (CharType(5), "char(5)"),
+            (VarcharType(10), "varchar(10)"),
+            (BinaryType(), "binary"),
+            (BooleanType(), "boolean"),
+            (DecimalType(), "decimal(10,0)"),
+            (DecimalType(10, 2), "decimal(10,2)"),
+            (FloatType(), "float"),
+            (DoubleType(), "double"),
+            (ByteType(), "byte"),
+            (ShortType(), "short"),
+            (IntegerType(), "integer"),
+            (LongType(), "long"),
+            (DateType(), "date"),
+            (TimestampType(), "timestamp"),
+            (TimestampNTZType(), "timestamp_ntz"),
+            (NullType(), "void"),
+            (VariantType(), "variant"),
+            (YearMonthIntervalType(), "interval year to month"),
+            (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"),
+            (
+                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
+                "interval year to month",
+            ),
+            (DayTimeIntervalType(), "interval day to second"),
+            (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
+            (
+                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
+                "interval hour to second",
+            ),
+            (CalendarIntervalType(), "interval"),
+        ]:
+            self.assertEqual(dataType.jsonValue(), expected)
+
     def test_apply_schema_to_row(self):
         df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
         df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index b9db59e0a58a..563c63f5dfb1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -115,7 +115,11 @@ class DataType:
         return hash(str(self))

     def __eq__(self, other: Any) -> bool:
-        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+        if isinstance(other, self.__class__):
+            self_dict = {k: v for k, v in self.__dict__.items() if k != "typeName"}
+            other_dict = {k: v for k, v in other.__dict__.items() if k != "typeName"}
+            return self_dict == other_dict
+        return False

     def __ne__(self, other: Any) -> bool:
         return not self.__eq__(other)
@@ -124,6 +128,12 @@ def typeName(cls) -> str:
         return cls.__name__[:-4].lower()

+    # The classmethod 'typeName' is not always consistent with the Scala side, e.g.
+    # DecimalType(10, 2): 'decimal' vs 'decimal(10, 2)'
+    # This method is used in subclass initializer to replace 'typeName' if they are different.
+    def _type_name(self) -> str:
+        return self.__class__.__name__.removesuffix("Type").removesuffix("UDT").lower()
+
     def simpleString(self) -> str:
         return self.typeName()
@@ -215,24 +225,6 @@ class DataType:
         if isinstance(dataType, (ArrayType, StructType, MapType)):
             dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1)

-    # The method typeName() is not always the same as the Scala side.
-    # Add this helper method to make TreeString() compatible with Scala side.
-    @classmethod
-    def _get_jvm_type_name(cls, dataType: "DataType") -> str:
-        if isinstance(
-            dataType,
-            (
-                DecimalType,
-                CharType,
-                VarcharType,
-                DayTimeIntervalType,
-                YearMonthIntervalType,
-            ),
-        ):
-            return dataType.simpleString()
-        else:
-            return dataType.typeName()
-
     # This singleton pattern does not work with pickle, you will get
     # another object after pickle and unpickle
@@ -294,6 +286,7 @@ class StringType(AtomicType):
     providers = [providerSpark, providerICU]

     def __init__(self, collation: Optional[str] = None):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         self.collationId = 0 if collation is None else self.collationNameToId(collation)

     @classmethod
@@ -315,7 +308,7 @@ class StringType(AtomicType):
             return StringType.providerSpark
         return StringType.providerICU

-    def simpleString(self) -> str:
+    def _type_name(self) -> str:
         if self.isUTF8BinaryCollation():
             return "string"

@@ -348,12 +341,10 @@ class CharType(AtomicType):
     """

     def __init__(self, length: int):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         self.length = length

-    def simpleString(self) -> str:
-        return "char(%d)" % (self.length)
-
-    def jsonValue(self) -> str:
+    def _type_name(self) -> str:
         return "char(%d)" % (self.length)

     def __repr__(self) -> str:
@@ -370,12 +361,10 @@ class VarcharType(AtomicType):
     """

     def __init__(self, length: int):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         self.length = length

-    def simpleString(self) -> str:
-        return "varchar(%d)" % (self.length)
-
-    def jsonValue(self) -> str:
+    def _type_name(self) -> str:
         return "varchar(%d)" % (self.length)

     def __repr__(self) -> str:
@@ -474,14 +463,12 @@ class DecimalType(FractionalType):
     """

     def __init__(self, precision: int = 10, scale: int = 0):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         self.precision = precision
         self.scale = scale
         self.hasPrecisionInfo = True  # this is a public API

-    def simpleString(self) -> str:
-        return "decimal(%d,%d)" % (self.precision, self.scale)
-
-    def jsonValue(self) -> str:
+    def _type_name(self) -> str:
         return "decimal(%d,%d)" % (self.precision, self.scale)

     def __repr__(self) -> str:
@@ -556,6 +543,7 @@ class DayTimeIntervalType(AnsiIntervalType):
     _inverted_fields = dict(zip(_fields.values(), _fields.keys()))

     def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         if startField is None and endField is None:
             # Default matched to scala side.
             startField = DayTimeIntervalType.DAY
@@ -572,7 +560,7 @@ class DayTimeIntervalType(AnsiIntervalType):
         self.startField = startField
         self.endField = endField

-    def _str_repr(self) -> str:
+    def _type_name(self) -> str:
         fields = DayTimeIntervalType._fields
         start_field_name = fields[self.startField]
         end_field_name = fields[self.endField]
@@ -581,10 +569,6 @@ class DayTimeIntervalType(AnsiIntervalType):
         else:
             return "interval %s to %s" % (start_field_name, end_field_name)

-    simpleString = _str_repr
-
-    jsonValue = _str_repr
-
     def __repr__(self) -> str:
         return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField)

@@ -614,6 +598,7 @@ class YearMonthIntervalType(AnsiIntervalType):
     _inverted_fields = dict(zip(_fields.values(), _fields.keys()))

     def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         if startField is None and endField is None:
             # Default matched to scala side.
             startField = YearMonthIntervalType.YEAR
@@ -630,7 +615,7 @@ class YearMonthIntervalType(AnsiIntervalType):
         self.startField = startField
         self.endField = endField

-    def _str_repr(self) -> str:
+    def _type_name(self) -> str:
         fields = YearMonthIntervalType._fields
         start_field_name = fields[self.startField]
         end_field_name = fields[self.endField]
@@ -639,10 +624,6 @@ class YearMonthIntervalType(AnsiIntervalType):
         else:
             return "interval %s to %s" % (start_field_name, end_field_name)

-    simpleString = _str_repr
-
-    jsonValue = _str_repr
-
     def __repr__(self) -> str:
         return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField)

@@ -776,7 +757,7 @@ class ArrayType(DataType):
     ) -> None:
         if maxDepth > 0:
             stringConcat.append(
-                f"{prefix}-- element: {DataType._get_jvm_type_name(self.elementType)} "
+                f"{prefix}-- element: {self.elementType.typeName()} "
                 + f"(containsNull = {str(self.containsNull).lower()})\n"
             )
             DataType._data_type_build_formatted_string(
@@ -924,12 +905,12 @@ class MapType(DataType):
         maxDepth: int = JVM_INT_MAX,
     ) -> None:
         if maxDepth > 0:
-            stringConcat.append(f"{prefix}-- key: {DataType._get_jvm_type_name(self.keyType)}\n")
+            stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n")
             DataType._data_type_build_formatted_string(
                 self.keyType, f"{prefix} |", stringConcat, maxDepth
             )
             stringConcat.append(
-                f"{prefix}-- value: {DataType._get_jvm_type_name(self.valueType)} "
+                f"{prefix}-- value: {self.valueType.typeName()} "
                 + f"(valueContainsNull = {str(self.valueContainsNull).lower()})\n"
             )
             DataType._data_type_build_formatted_string(
@@ -1092,8 +1073,7 @@ class StructField(DataType):
     ) -> None:
         if maxDepth > 0:
             stringConcat.append(
-                f"{prefix}-- {escape_meta_characters(self.name)}: "
-                + f"{DataType._get_jvm_type_name(self.dataType)} "
+                f"{prefix}-- {escape_meta_characters(self.name)}: {self.dataType.typeName()} "
                 + f"(nullable = {str(self.nullable).lower()})\n"
             )
             DataType._data_type_build_formatted_string(
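For quick reference, here is a minimal sketch of the user-facing change described above. The expected strings are taken verbatim from the new tests; running it assumes a PySpark build that includes this commit.

```python
from pyspark.sql.types import (
    CharType,
    DayTimeIntervalType,
    DecimalType,
    VarcharType,
    YearMonthIntervalType,
)

# Instance-level typeName() now carries the type parameters,
# matching simpleString()/jsonValue() and the Scala side.
print(CharType(10).typeName())        # char(10)       (was: char)
print(VarcharType(10).typeName())     # varchar(10)    (was: varchar)
print(DecimalType(10, 2).typeName())  # decimal(10,2)  (was: decimal)
print(
    DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.HOUR).typeName()
)  # interval day to hour    (was: daytimeinterval)
print(
    YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH).typeName()
)  # interval year to month  (was: yearmonthinterval)

# The classmethod form is unchanged and still returns the bare name:
print(DecimalType.typeName())  # decimal
```

The mechanism, per the patch: each parameterized type's `__init__` assigns `self.typeName = self._type_name`, so the instance attribute shadows the inherited classmethod, and `DataType.__eq__` now skips the `typeName` key in `__dict__` so equality semantics are unchanged.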