This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 910c3733bfdd Revert "[SPARK-48415][PYTHON] Refactor TypeName to support parameterized datatypes"
910c3733bfdd is described below

commit 910c3733bfdd1a0f386137d48796e317f64f7f50
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu May 30 16:21:22 2024 +0800

    Revert "[SPARK-48415][PYTHON] Refactor TypeName to support parameterized 
datatypes"
    
    revert https://github.com/apache/spark/pull/46738
    
    Closes #46804 from zhengruifeng/revert_typename_oss.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_types.py | 133 ---------------------------------
 python/pyspark/sql/types.py            |  74 +++++++++++-------
 2 files changed, 47 insertions(+), 160 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index cc482b886e3a..80f2c0fcbc03 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -81,139 +81,6 @@ from pyspark.testing.utils import PySparkErrorTestUtils
 
 
 class TypesTestsMixin:
-    def test_class_method_type_name(self):
-        for dataType, expected in [
-            (StringType, "string"),
-            (CharType, "char"),
-            (VarcharType, "varchar"),
-            (BinaryType, "binary"),
-            (BooleanType, "boolean"),
-            (DecimalType, "decimal"),
-            (FloatType, "float"),
-            (DoubleType, "double"),
-            (ByteType, "byte"),
-            (ShortType, "short"),
-            (IntegerType, "integer"),
-            (LongType, "long"),
-            (DateType, "date"),
-            (TimestampType, "timestamp"),
-            (TimestampNTZType, "timestamp_ntz"),
-            (NullType, "void"),
-            (VariantType, "variant"),
-            (YearMonthIntervalType, "yearmonthinterval"),
-            (DayTimeIntervalType, "daytimeinterval"),
-            (CalendarIntervalType, "interval"),
-        ]:
-            self.assertEqual(dataType.typeName(), expected)
-
-    def test_instance_method_type_name(self):
-        for dataType, expected in [
-            (StringType(), "string"),
-            (CharType(5), "char(5)"),
-            (VarcharType(10), "varchar(10)"),
-            (BinaryType(), "binary"),
-            (BooleanType(), "boolean"),
-            (DecimalType(), "decimal(10,0)"),
-            (DecimalType(10, 2), "decimal(10,2)"),
-            (FloatType(), "float"),
-            (DoubleType(), "double"),
-            (ByteType(), "byte"),
-            (ShortType(), "short"),
-            (IntegerType(), "integer"),
-            (LongType(), "long"),
-            (DateType(), "date"),
-            (TimestampType(), "timestamp"),
-            (TimestampNTZType(), "timestamp_ntz"),
-            (NullType(), "void"),
-            (VariantType(), "variant"),
-            (YearMonthIntervalType(), "interval year to month"),
-            (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"),
-            (
-                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
-                "interval year to month",
-            ),
-            (DayTimeIntervalType(), "interval day to second"),
-            (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
-            (
-                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
-                "interval hour to second",
-            ),
-            (CalendarIntervalType(), "interval"),
-        ]:
-            self.assertEqual(dataType.typeName(), expected)
-
-    def test_simple_string(self):
-        for dataType, expected in [
-            (StringType(), "string"),
-            (CharType(5), "char(5)"),
-            (VarcharType(10), "varchar(10)"),
-            (BinaryType(), "binary"),
-            (BooleanType(), "boolean"),
-            (DecimalType(), "decimal(10,0)"),
-            (DecimalType(10, 2), "decimal(10,2)"),
-            (FloatType(), "float"),
-            (DoubleType(), "double"),
-            (ByteType(), "tinyint"),
-            (ShortType(), "smallint"),
-            (IntegerType(), "int"),
-            (LongType(), "bigint"),
-            (DateType(), "date"),
-            (TimestampType(), "timestamp"),
-            (TimestampNTZType(), "timestamp_ntz"),
-            (NullType(), "void"),
-            (VariantType(), "variant"),
-            (YearMonthIntervalType(), "interval year to month"),
-            (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"),
-            (
-                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
-                "interval year to month",
-            ),
-            (DayTimeIntervalType(), "interval day to second"),
-            (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
-            (
-                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
-                "interval hour to second",
-            ),
-            (CalendarIntervalType(), "interval"),
-        ]:
-            self.assertEqual(dataType.simpleString(), expected)
-
-    def test_json_value(self):
-        for dataType, expected in [
-            (StringType(), "string"),
-            (CharType(5), "char(5)"),
-            (VarcharType(10), "varchar(10)"),
-            (BinaryType(), "binary"),
-            (BooleanType(), "boolean"),
-            (DecimalType(), "decimal(10,0)"),
-            (DecimalType(10, 2), "decimal(10,2)"),
-            (FloatType(), "float"),
-            (DoubleType(), "double"),
-            (ByteType(), "byte"),
-            (ShortType(), "short"),
-            (IntegerType(), "integer"),
-            (LongType(), "long"),
-            (DateType(), "date"),
-            (TimestampType(), "timestamp"),
-            (TimestampNTZType(), "timestamp_ntz"),
-            (NullType(), "void"),
-            (VariantType(), "variant"),
-            (YearMonthIntervalType(), "interval year to month"),
-            (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"),
-            (
-                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
-                "interval year to month",
-            ),
-            (DayTimeIntervalType(), "interval day to second"),
-            (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
-            (
-                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
-                "interval hour to second",
-            ),
-            (CalendarIntervalType(), "interval"),
-        ]:
-            self.assertEqual(dataType.jsonValue(), expected)
-
     def test_apply_schema_to_row(self):
         df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
         df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c72ff72ce426..c0f60f839356 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -115,11 +115,7 @@ class DataType:
         return hash(str(self))
 
     def __eq__(self, other: Any) -> bool:
-        if isinstance(other, self.__class__):
-            self_dict = {k: v for k, v in self.__dict__.items() if k != "typeName"}
-            other_dict = {k: v for k, v in other.__dict__.items() if k != "typeName"}
-            return self_dict == other_dict
-        return False
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
 
     def __ne__(self, other: Any) -> bool:
         return not self.__eq__(other)
@@ -128,12 +124,6 @@ class DataType:
     def typeName(cls) -> str:
         return cls.__name__[:-4].lower()
 
-    # The classmethod 'typeName' is not always consistent with the Scala side, e.g.
-    # DecimalType(10, 2): 'decimal' vs 'decimal(10, 2)'
-    # This method is used in subclass initializer to replace 'typeName' if they are different.
-    def _type_name(self) -> str:
-        return self.__class__.__name__.removesuffix("Type").removesuffix("UDT").lower()
-
     def simpleString(self) -> str:
         return self.typeName()
 
@@ -225,6 +215,24 @@ class DataType:
         if isinstance(dataType, (ArrayType, StructType, MapType)):
             dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1)
 
+    # The method typeName() is not always the same as the Scala side.
+    # Add this helper method to make TreeString() compatible with Scala side.
+    @classmethod
+    def _get_jvm_type_name(cls, dataType: "DataType") -> str:
+        if isinstance(
+            dataType,
+            (
+                DecimalType,
+                CharType,
+                VarcharType,
+                DayTimeIntervalType,
+                YearMonthIntervalType,
+            ),
+        ):
+            return dataType.simpleString()
+        else:
+            return dataType.typeName()
+
 
 # This singleton pattern does not work with pickle, you will get
 # another object after pickle and unpickle
@@ -285,7 +293,6 @@ class StringType(AtomicType):
     providers = [providerSpark, providerICU]
 
     def __init__(self, collation: str = "UTF8_BINARY"):
-        self.typeName = self._type_name  # type: ignore[method-assign]
         self.collation = collation
 
     @classmethod
@@ -295,7 +302,7 @@ class StringType(AtomicType):
             return StringType.providerSpark
         return StringType.providerICU
 
-    def _type_name(self) -> str:
+    def simpleString(self) -> str:
         if self.isUTF8BinaryCollation():
             return "string"
 
@@ -326,10 +333,12 @@ class CharType(AtomicType):
     """
 
     def __init__(self, length: int):
-        self.typeName = self._type_name  # type: ignore[method-assign]
         self.length = length
 
-    def _type_name(self) -> str:
+    def simpleString(self) -> str:
+        return "char(%d)" % (self.length)
+
+    def jsonValue(self) -> str:
         return "char(%d)" % (self.length)
 
     def __repr__(self) -> str:
@@ -346,10 +355,12 @@ class VarcharType(AtomicType):
     """
 
     def __init__(self, length: int):
-        self.typeName = self._type_name  # type: ignore[method-assign]
         self.length = length
 
-    def _type_name(self) -> str:
+    def simpleString(self) -> str:
+        return "varchar(%d)" % (self.length)
+
+    def jsonValue(self) -> str:
         return "varchar(%d)" % (self.length)
 
     def __repr__(self) -> str:
@@ -448,12 +459,14 @@ class DecimalType(FractionalType):
     """
 
     def __init__(self, precision: int = 10, scale: int = 0):
-        self.typeName = self._type_name  # type: ignore[method-assign]
         self.precision = precision
         self.scale = scale
         self.hasPrecisionInfo = True  # this is a public API
 
-    def _type_name(self) -> str:
+    def simpleString(self) -> str:
+        return "decimal(%d,%d)" % (self.precision, self.scale)
+
+    def jsonValue(self) -> str:
         return "decimal(%d,%d)" % (self.precision, self.scale)
 
     def __repr__(self) -> str:
@@ -528,7 +541,6 @@ class DayTimeIntervalType(AnsiIntervalType):
     _inverted_fields = dict(zip(_fields.values(), _fields.keys()))
 
     def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
-        self.typeName = self._type_name  # type: ignore[method-assign]
         if startField is None and endField is None:
             # Default matched to scala side.
             startField = DayTimeIntervalType.DAY
@@ -545,7 +557,7 @@ class DayTimeIntervalType(AnsiIntervalType):
         self.startField = startField
         self.endField = endField
 
-    def _type_name(self) -> str:
+    def _str_repr(self) -> str:
         fields = DayTimeIntervalType._fields
         start_field_name = fields[self.startField]
         end_field_name = fields[self.endField]
@@ -554,6 +566,10 @@ class DayTimeIntervalType(AnsiIntervalType):
         else:
             return "interval %s to %s" % (start_field_name, end_field_name)
 
+    simpleString = _str_repr
+
+    jsonValue = _str_repr
+
     def __repr__(self) -> str:
         return "%s(%d, %d)" % (type(self).__name__, self.startField, 
self.endField)
 
@@ -583,7 +599,6 @@ class YearMonthIntervalType(AnsiIntervalType):
     _inverted_fields = dict(zip(_fields.values(), _fields.keys()))
 
     def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
-        self.typeName = self._type_name  # type: ignore[method-assign]
         if startField is None and endField is None:
             # Default matched to scala side.
             startField = YearMonthIntervalType.YEAR
@@ -600,7 +615,7 @@ class YearMonthIntervalType(AnsiIntervalType):
         self.startField = startField
         self.endField = endField
 
-    def _type_name(self) -> str:
+    def _str_repr(self) -> str:
         fields = YearMonthIntervalType._fields
         start_field_name = fields[self.startField]
         end_field_name = fields[self.endField]
@@ -609,6 +624,10 @@ class YearMonthIntervalType(AnsiIntervalType):
         else:
             return "interval %s to %s" % (start_field_name, end_field_name)
 
+    simpleString = _str_repr
+
+    jsonValue = _str_repr
+
     def __repr__(self) -> str:
         return "%s(%d, %d)" % (type(self).__name__, self.startField, 
self.endField)
 
@@ -742,7 +761,7 @@ class ArrayType(DataType):
     ) -> None:
         if maxDepth > 0:
             stringConcat.append(
-                f"{prefix}-- element: {self.elementType.typeName()} "
+                f"{prefix}-- element: 
{DataType._get_jvm_type_name(self.elementType)} "
                 + f"(containsNull = {str(self.containsNull).lower()})\n"
             )
             DataType._data_type_build_formatted_string(
@@ -890,12 +909,12 @@ class MapType(DataType):
         maxDepth: int = JVM_INT_MAX,
     ) -> None:
         if maxDepth > 0:
-            stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n")
+            stringConcat.append(f"{prefix}-- key: 
{DataType._get_jvm_type_name(self.keyType)}\n")
             DataType._data_type_build_formatted_string(
                 self.keyType, f"{prefix}    |", stringConcat, maxDepth
             )
             stringConcat.append(
-                f"{prefix}-- value: {self.valueType.typeName()} "
+                f"{prefix}-- value: 
{DataType._get_jvm_type_name(self.valueType)} "
                 + f"(valueContainsNull = 
{str(self.valueContainsNull).lower()})\n"
             )
             DataType._data_type_build_formatted_string(
@@ -1058,7 +1077,8 @@ class StructField(DataType):
     ) -> None:
         if maxDepth > 0:
             stringConcat.append(
-                f"{prefix}-- {escape_meta_characters(self.name)}: 
{self.dataType.typeName()} "
+                f"{prefix}-- {escape_meta_characters(self.name)}: "
+                + f"{DataType._get_jvm_type_name(self.dataType)} "
                 + f"(nullable = {str(self.nullable).lower()})\n"
             )
             DataType._data_type_build_formatted_string(

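For readers skimming the patch, below is a minimal sketch of the behavior this
revert restores. Expected values are taken from the reverted tests and the
hunks above; it assumes a PySpark build that includes this commit, and it pokes
at the private helper _get_jvm_type_name for illustration only.

    from pyspark.sql.types import (
        DataType,
        DecimalType,
        VarcharType,
        DayTimeIntervalType,
    )

    dt = DecimalType(10, 2)

    # typeName() is once again a plain classmethod derived from the class
    # name, so it carries no precision/scale information.
    assert DecimalType.typeName() == "decimal"
    assert dt.typeName() == "decimal"

    # Type parameters surface through the restored simpleString()/jsonValue()
    # overrides rather than an instance-level typeName reassignment in __init__.
    assert dt.simpleString() == "decimal(10,2)"
    assert dt.jsonValue() == "decimal(10,2)"

    # The restored helper picks simpleString() for parameterized types so
    # treeString output stays compatible with the Scala side.
    assert DataType._get_jvm_type_name(dt) == "decimal(10,2)"
    assert DataType._get_jvm_type_name(VarcharType(10)) == "varchar(10)"
    assert DataType._get_jvm_type_name(DayTimeIntervalType()) == "interval day to second"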

