This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fc1435d14d09 [SPARK-48415][PYTHON] Refactor `TypeName` to support parameterized datatypes
fc1435d14d09 is described below

commit fc1435d14d090b792a0f19372d6b11c7ff026372
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue May 28 08:39:28 2024 +0800

    [SPARK-48415][PYTHON] Refactor `TypeName` to support parameterized datatypes
    
    ### What changes were proposed in this pull request?
    1. Refactor the instance method `typeName` to support parameterized datatypes.
    2. Remove the redundant `simpleString`/`jsonValue` methods, since they now default to the type name (see the sketch below).
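
    As a minimal standalone sketch of the pattern (simplified from the diff below): parameterized
    types bind an instance-level `typeName` in `__init__` that shadows the classmethod, while
    unparameterized types keep the class-level default.

    ```python
    class DataType:
        @classmethod
        def typeName(cls) -> str:
            return cls.__name__.removesuffix("Type").lower()


    class CharType(DataType):
        def __init__(self, length: int):
            # The instance attribute shadows the classmethod for this instance only.
            self.typeName = self._type_name
            self.length = length

        def _type_name(self) -> str:
            return "char(%d)" % self.length


    assert CharType.typeName() == "char"          # class-level, no parameters
    assert CharType(10).typeName() == "char(10)"  # instance-level, parameterized
    ```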
    
    ### Why are the changes needed?
    to be consistent with the Scala side
    
    ### Does this PR introduce _any_ user-facing change?
    
    Type name changes:
    `CharType(10)`: `char` -> `char(10)`
    `VarcharType(10)`: `varchar` -> `varchar(10)`
    `DecimalType(10, 2)`: `decimal` -> `decimal(10,2)`
    `DayTimeIntervalType(DAY, HOUR)`: `daytimeinterval` -> `interval day to hour`
    `YearMonthIntervalType(YEAR, MONTH)`: `yearmonthinterval` -> `interval year to month`
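
    For example, a quick illustration of the new outputs listed above:

    ```python
    from pyspark.sql.types import CharType, DecimalType, DayTimeIntervalType

    CharType(10).typeName()        # 'char(10)'             (was 'char')
    DecimalType(10, 2).typeName()  # 'decimal(10,2)'        (was 'decimal')
    DayTimeIntervalType(
        DayTimeIntervalType.DAY, DayTimeIntervalType.HOUR
    ).typeName()                   # 'interval day to hour' (was 'daytimeinterval')
    ```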
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46738 from zhengruifeng/py_type_name.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_types.py | 133 +++++++++++++++++++++++++++++++++
 python/pyspark/sql/types.py            |  74 +++++++-----------
 2 files changed, 160 insertions(+), 47 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 80f2c0fcbc03..cc482b886e3a 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -81,6 +81,139 @@ from pyspark.testing.utils import PySparkErrorTestUtils
 
 
 class TypesTestsMixin:
+    def test_class_method_type_name(self):
+        for dataType, expected in [
+            (StringType, "string"),
+            (CharType, "char"),
+            (VarcharType, "varchar"),
+            (BinaryType, "binary"),
+            (BooleanType, "boolean"),
+            (DecimalType, "decimal"),
+            (FloatType, "float"),
+            (DoubleType, "double"),
+            (ByteType, "byte"),
+            (ShortType, "short"),
+            (IntegerType, "integer"),
+            (LongType, "long"),
+            (DateType, "date"),
+            (TimestampType, "timestamp"),
+            (TimestampNTZType, "timestamp_ntz"),
+            (NullType, "void"),
+            (VariantType, "variant"),
+            (YearMonthIntervalType, "yearmonthinterval"),
+            (DayTimeIntervalType, "daytimeinterval"),
+            (CalendarIntervalType, "interval"),
+        ]:
+            self.assertEqual(dataType.typeName(), expected)
+
+    def test_instance_method_type_name(self):
+        for dataType, expected in [
+            (StringType(), "string"),
+            (CharType(5), "char(5)"),
+            (VarcharType(10), "varchar(10)"),
+            (BinaryType(), "binary"),
+            (BooleanType(), "boolean"),
+            (DecimalType(), "decimal(10,0)"),
+            (DecimalType(10, 2), "decimal(10,2)"),
+            (FloatType(), "float"),
+            (DoubleType(), "double"),
+            (ByteType(), "byte"),
+            (ShortType(), "short"),
+            (IntegerType(), "integer"),
+            (LongType(), "long"),
+            (DateType(), "date"),
+            (TimestampType(), "timestamp"),
+            (TimestampNTZType(), "timestamp_ntz"),
+            (NullType(), "void"),
+            (VariantType(), "variant"),
+            (YearMonthIntervalType(), "interval year to month"),
+            (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"),
+            (
+                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
+                "interval year to month",
+            ),
+            (DayTimeIntervalType(), "interval day to second"),
+            (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
+            (
+                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
+                "interval hour to second",
+            ),
+            (CalendarIntervalType(), "interval"),
+        ]:
+            self.assertEqual(dataType.typeName(), expected)
+
+    def test_simple_string(self):
+        for dataType, expected in [
+            (StringType(), "string"),
+            (CharType(5), "char(5)"),
+            (VarcharType(10), "varchar(10)"),
+            (BinaryType(), "binary"),
+            (BooleanType(), "boolean"),
+            (DecimalType(), "decimal(10,0)"),
+            (DecimalType(10, 2), "decimal(10,2)"),
+            (FloatType(), "float"),
+            (DoubleType(), "double"),
+            (ByteType(), "tinyint"),
+            (ShortType(), "smallint"),
+            (IntegerType(), "int"),
+            (LongType(), "bigint"),
+            (DateType(), "date"),
+            (TimestampType(), "timestamp"),
+            (TimestampNTZType(), "timestamp_ntz"),
+            (NullType(), "void"),
+            (VariantType(), "variant"),
+            (YearMonthIntervalType(), "interval year to month"),
+            (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"),
+            (
+                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
+                "interval year to month",
+            ),
+            (DayTimeIntervalType(), "interval day to second"),
+            (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
+            (
+                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
+                "interval hour to second",
+            ),
+            (CalendarIntervalType(), "interval"),
+        ]:
+            self.assertEqual(dataType.simpleString(), expected)
+
+    def test_json_value(self):
+        for dataType, expected in [
+            (StringType(), "string"),
+            (CharType(5), "char(5)"),
+            (VarcharType(10), "varchar(10)"),
+            (BinaryType(), "binary"),
+            (BooleanType(), "boolean"),
+            (DecimalType(), "decimal(10,0)"),
+            (DecimalType(10, 2), "decimal(10,2)"),
+            (FloatType(), "float"),
+            (DoubleType(), "double"),
+            (ByteType(), "byte"),
+            (ShortType(), "short"),
+            (IntegerType(), "integer"),
+            (LongType(), "long"),
+            (DateType(), "date"),
+            (TimestampType(), "timestamp"),
+            (TimestampNTZType(), "timestamp_ntz"),
+            (NullType(), "void"),
+            (VariantType(), "variant"),
+            (YearMonthIntervalType(), "interval year to month"),
+            (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"),
+            (
+                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
+                "interval year to month",
+            ),
+            (DayTimeIntervalType(), "interval day to second"),
+            (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
+            (
+                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
+                "interval hour to second",
+            ),
+            (CalendarIntervalType(), "interval"),
+        ]:
+            self.assertEqual(dataType.jsonValue(), expected)
+
     def test_apply_schema_to_row(self):
         df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
         df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index b9db59e0a58a..563c63f5dfb1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -115,7 +115,11 @@ class DataType:
         return hash(str(self))
 
     def __eq__(self, other: Any) -> bool:
-        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+        if isinstance(other, self.__class__):
+            self_dict = {k: v for k, v in self.__dict__.items() if k != "typeName"}
+            other_dict = {k: v for k, v in other.__dict__.items() if k != "typeName"}
+            return self_dict == other_dict
+        return False
 
     def __ne__(self, other: Any) -> bool:
         return not self.__eq__(other)
@@ -124,6 +128,12 @@ class DataType:
     def typeName(cls) -> str:
         return cls.__name__[:-4].lower()
 
+    # The classmethod 'typeName' is not always consistent with the Scala side, e.g.
+    # DecimalType(10, 2): 'decimal' vs 'decimal(10, 2)'
+    # This method is used in subclass initializer to replace 'typeName' if they are different.
+    def _type_name(self) -> str:
+        return self.__class__.__name__.removesuffix("Type").removesuffix("UDT").lower()
+
     def simpleString(self) -> str:
         return self.typeName()
 
@@ -215,24 +225,6 @@ class DataType:
         if isinstance(dataType, (ArrayType, StructType, MapType)):
             dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1)
 
-    # The method typeName() is not always the same as the Scala side.
-    # Add this helper method to make TreeString() compatible with Scala side.
-    @classmethod
-    def _get_jvm_type_name(cls, dataType: "DataType") -> str:
-        if isinstance(
-            dataType,
-            (
-                DecimalType,
-                CharType,
-                VarcharType,
-                DayTimeIntervalType,
-                YearMonthIntervalType,
-            ),
-        ):
-            return dataType.simpleString()
-        else:
-            return dataType.typeName()
-
 
 # This singleton pattern does not work with pickle, you will get
 # another object after pickle and unpickle
@@ -294,6 +286,7 @@ class StringType(AtomicType):
     providers = [providerSpark, providerICU]
 
     def __init__(self, collation: Optional[str] = None):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         self.collationId = 0 if collation is None else self.collationNameToId(collation)
 
     @classmethod
@@ -315,7 +308,7 @@ class StringType(AtomicType):
             return StringType.providerSpark
         return StringType.providerICU
 
-    def simpleString(self) -> str:
+    def _type_name(self) -> str:
         if self.isUTF8BinaryCollation():
             return "string"
 
@@ -348,12 +341,10 @@ class CharType(AtomicType):
     """
 
     def __init__(self, length: int):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         self.length = length
 
-    def simpleString(self) -> str:
-        return "char(%d)" % (self.length)
-
-    def jsonValue(self) -> str:
+    def _type_name(self) -> str:
         return "char(%d)" % (self.length)
 
     def __repr__(self) -> str:
@@ -370,12 +361,10 @@ class VarcharType(AtomicType):
     """
 
     def __init__(self, length: int):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         self.length = length
 
-    def simpleString(self) -> str:
-        return "varchar(%d)" % (self.length)
-
-    def jsonValue(self) -> str:
+    def _type_name(self) -> str:
         return "varchar(%d)" % (self.length)
 
     def __repr__(self) -> str:
@@ -474,14 +463,12 @@ class DecimalType(FractionalType):
     """
 
     def __init__(self, precision: int = 10, scale: int = 0):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         self.precision = precision
         self.scale = scale
         self.hasPrecisionInfo = True  # this is a public API
 
-    def simpleString(self) -> str:
-        return "decimal(%d,%d)" % (self.precision, self.scale)
-
-    def jsonValue(self) -> str:
+    def _type_name(self) -> str:
         return "decimal(%d,%d)" % (self.precision, self.scale)
 
     def __repr__(self) -> str:
@@ -556,6 +543,7 @@ class DayTimeIntervalType(AnsiIntervalType):
     _inverted_fields = dict(zip(_fields.values(), _fields.keys()))
 
     def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         if startField is None and endField is None:
             # Default matched to scala side.
             startField = DayTimeIntervalType.DAY
@@ -572,7 +560,7 @@ class DayTimeIntervalType(AnsiIntervalType):
         self.startField = startField
         self.endField = endField
 
-    def _str_repr(self) -> str:
+    def _type_name(self) -> str:
         fields = DayTimeIntervalType._fields
         start_field_name = fields[self.startField]
         end_field_name = fields[self.endField]
@@ -581,10 +569,6 @@ class DayTimeIntervalType(AnsiIntervalType):
         else:
             return "interval %s to %s" % (start_field_name, end_field_name)
 
-    simpleString = _str_repr
-
-    jsonValue = _str_repr
-
     def __repr__(self) -> str:
         return "%s(%d, %d)" % (type(self).__name__, self.startField, 
self.endField)
 
@@ -614,6 +598,7 @@ class YearMonthIntervalType(AnsiIntervalType):
     _inverted_fields = dict(zip(_fields.values(), _fields.keys()))
 
     def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
+        self.typeName = self._type_name  # type: ignore[method-assign]
         if startField is None and endField is None:
             # Default matched to scala side.
             startField = YearMonthIntervalType.YEAR
@@ -630,7 +615,7 @@ class YearMonthIntervalType(AnsiIntervalType):
         self.startField = startField
         self.endField = endField
 
-    def _str_repr(self) -> str:
+    def _type_name(self) -> str:
         fields = YearMonthIntervalType._fields
         start_field_name = fields[self.startField]
         end_field_name = fields[self.endField]
@@ -639,10 +624,6 @@ class YearMonthIntervalType(AnsiIntervalType):
         else:
             return "interval %s to %s" % (start_field_name, end_field_name)
 
-    simpleString = _str_repr
-
-    jsonValue = _str_repr
-
     def __repr__(self) -> str:
         return "%s(%d, %d)" % (type(self).__name__, self.startField, 
self.endField)
 
@@ -776,7 +757,7 @@ class ArrayType(DataType):
     ) -> None:
         if maxDepth > 0:
             stringConcat.append(
-                f"{prefix}-- element: 
{DataType._get_jvm_type_name(self.elementType)} "
+                f"{prefix}-- element: {self.elementType.typeName()} "
                 + f"(containsNull = {str(self.containsNull).lower()})\n"
             )
             DataType._data_type_build_formatted_string(
@@ -924,12 +905,12 @@ class MapType(DataType):
         maxDepth: int = JVM_INT_MAX,
     ) -> None:
         if maxDepth > 0:
-            stringConcat.append(f"{prefix}-- key: 
{DataType._get_jvm_type_name(self.keyType)}\n")
+            stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n")
             DataType._data_type_build_formatted_string(
                 self.keyType, f"{prefix}    |", stringConcat, maxDepth
             )
             stringConcat.append(
-                f"{prefix}-- value: 
{DataType._get_jvm_type_name(self.valueType)} "
+                f"{prefix}-- value: {self.valueType.typeName()} "
                 + f"(valueContainsNull = 
{str(self.valueContainsNull).lower()})\n"
             )
             DataType._data_type_build_formatted_string(
@@ -1092,8 +1073,7 @@ class StructField(DataType):
     ) -> None:
         if maxDepth > 0:
             stringConcat.append(
-                f"{prefix}-- {escape_meta_characters(self.name)}: "
-                + f"{DataType._get_jvm_type_name(self.dataType)} "
+                f"{prefix}-- {escape_meta_characters(self.name)}: 
{self.dataType.typeName()} "
                 + f"(nullable = {str(self.nullable).lower()})\n"
             )
             DataType._data_type_build_formatted_string(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
