This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 00d7094dc30 [SPARK-39809][PYTHON] Support CharType in PySpark
00d7094dc30 is described below

commit 00d7094dc3024ae594605b311dcc55e95d277d5f
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Jul 19 10:22:04 2022 +0900

    [SPARK-39809][PYTHON] Support CharType in PySpark
    
    ### What changes were proposed in this pull request?
    Support CharType in PySpark
    
    ### Why are the changes needed?
    for function parity
    
    ### Does this PR introduce _any_ user-facing change?
    yes, new type added
    
    ### How was this patch tested?
    added UT
    
    Closes #37215 from zhengruifeng/py_add_char.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_types.py | 26 ++++++++++++++++++---
 python/pyspark/sql/types.py            | 42 +++++++++++++++++++++++++++++++---
 2 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 218cfc413db..b1609417a0c 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -38,6 +38,7 @@ from pyspark.sql.types import (
     DayTimeIntervalType,
     MapType,
     StringType,
+    CharType,
     VarcharType,
     StructType,
     StructField,
@@ -740,9 +741,12 @@ class TypesTests(ReusedSQLTestCase):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
 
         for k, t in _all_atomic_types.items():
-            if k != "varchar":
+            if k != "varchar" and k != "char":
                 self.assertEqual(t(), _parse_datatype_string(k))
         self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+        self.assertEqual(CharType(1), _parse_datatype_string("char(1)"))
+        self.assertEqual(CharType(10), _parse_datatype_string("char( 10   )"))
+        self.assertEqual(CharType(11), _parse_datatype_string("char( 11)"))
         self.assertEqual(VarcharType(1), _parse_datatype_string("varchar(1)"))
         self.assertEqual(VarcharType(10), _parse_datatype_string("varchar( 10   )"))
         self.assertEqual(VarcharType(11), _parse_datatype_string("varchar( 11)"))
@@ -1033,6 +1037,7 @@ class TypesTests(ReusedSQLTestCase):
         instances = [
             NullType(),
             StringType(),
+            CharType(10),
             VarcharType(10),
             BinaryType(),
             BooleanType(),
@@ -1138,6 +1143,15 @@ class DataTypeTests(unittest.TestCase):
         t3 = DecimalType(8)
         self.assertNotEqual(t2, t3)
 
+    def test_char_type(self):
+        v1 = CharType(10)
+        v2 = CharType(20)
+        self.assertTrue(v2 is not v1)
+        self.assertNotEqual(v1, v2)
+        v3 = CharType(10)
+        self.assertEqual(v1, v3)
+        self.assertFalse(v1 is v3)
+
     def test_varchar_type(self):
         v1 = VarcharType(10)
         v2 = VarcharType(20)
@@ -1221,14 +1235,18 @@ class DataTypeVerificationTests(unittest.TestCase):
         success_spec = [
             # String
             ("", StringType()),
-            ("", StringType()),
             (1, StringType()),
             (1.0, StringType()),
             ([], StringType()),
             ({}, StringType()),
+            # Char
+            ("", CharType(10)),
+            (1, CharType(10)),
+            (1.0, CharType(10)),
+            ([], CharType(10)),
+            ({}, CharType(10)),
             # Varchar
             ("", VarcharType(10)),
-            ("", VarcharType(10)),
             (1, VarcharType(10)),
             (1.0, VarcharType(10)),
             ([], VarcharType(10)),
@@ -1289,6 +1307,8 @@ class DataTypeVerificationTests(unittest.TestCase):
         failure_spec = [
             # String (match anything but None)
             (None, StringType(), ValueError),
+            # CharType (match anything but None)
+            (None, CharType(10), ValueError),
             # VarcharType (match anything but None)
             (None, VarcharType(10), ValueError),
             # UDT
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 7ab8f7c9c2d..e034ff75e10 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -56,6 +56,7 @@ U = TypeVar("U")
 __all__ = [
     "DataType",
     "NullType",
+    "CharType",
     "StringType",
     "VarcharType",
     "BinaryType",
@@ -182,6 +183,28 @@ class StringType(AtomicType, metaclass=DataTypeSingleton):
     pass
 
 
+class CharType(AtomicType):
+    """Char data type
+
+    Parameters
+    ----------
+    length : int
+        the length limitation.
+    """
+
+    def __init__(self, length: int):
+        self.length = length
+
+    def simpleString(self) -> str:
+        return "char(%d)" % (self.length)
+
+    def jsonValue(self) -> str:
+        return "char(%d)" % (self.length)
+
+    def __repr__(self) -> str:
+        return "CharType(%d)" % (self.length)
+
+
 class VarcharType(AtomicType):
     """Varchar data type
 
@@ -648,6 +671,10 @@ class StructType(DataType):
     >>> struct2 = StructType([StructField("f1", StringType(), True)])
     >>> struct1 == struct2
     True
+    >>> struct1 = StructType([StructField("f1", CharType(10), True)])
+    >>> struct2 = StructType([StructField("f1", CharType(10), True)])
+    >>> struct1 == struct2
+    True
     >>> struct1 = StructType([StructField("f1", VarcharType(10), True)])
     >>> struct2 = StructType([StructField("f1", VarcharType(10), True)])
     >>> struct1 == struct2
@@ -971,6 +998,7 @@ class UserDefinedType(DataType):
 
 _atomic_types: List[Type[DataType]] = [
     StringType,
+    CharType,
     VarcharType,
     BinaryType,
     BooleanType,
@@ -993,6 +1021,7 @@ _all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dic
     (v.typeName(), v) for v in _complex_types
 )
 
+_LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)")
 _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
@@ -1015,6 +1044,8 @@ def _parse_datatype_string(s: str) -> DataType:
     StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)])
     >>> _parse_datatype_string("a DOUBLE, b STRING")
     StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)])
+    >>> _parse_datatype_string("a DOUBLE, b CHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', CharType(50), True)])
     >>> _parse_datatype_string("a DOUBLE, b VARCHAR( 50 )")
     StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)])
     >>> _parse_datatype_string("a: array< short>")
@@ -1085,7 +1116,7 @@ def _parse_datatype_json_string(json_string: str) -> DataType:
     ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
     ...     assert datatype == python_datatype
     >>> for cls in _all_atomic_types.values():
-    ...     if cls is not VarcharType:
+    ...     if cls is not VarcharType and cls is not CharType:
     ...         check_datatype(cls())
     ...     else:
     ...         check_datatype(cls(1))
@@ -1112,6 +1143,7 @@ def _parse_datatype_json_string(json_string: str) -> DataType:
     ...     StructField("simpleMap", simple_maptype, True),
     ...     StructField("simpleStruct", simple_structtype, True),
     ...     StructField("boolean", BooleanType(), False),
+    ...     StructField("chars", CharType(10), False),
     ...     StructField("words", VarcharType(10), False),
     ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
     >>> check_datatype(complex_structtype)
@@ -1145,6 +1177,9 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
             if first_field is not None and second_field is None:
                 return DayTimeIntervalType(first_field)
             return DayTimeIntervalType(first_field, second_field)
+        elif _LENGTH_CHAR.match(json_value):
+            m = _LENGTH_CHAR.match(json_value)
+            return CharType(int(m.group(1)))  # type: ignore[union-attr]
         elif _LENGTH_VARCHAR.match(json_value):
             m = _LENGTH_VARCHAR.match(json_value)
             return VarcharType(int(m.group(1)))  # type: ignore[union-attr]
@@ -1586,6 +1621,7 @@ _acceptable_types = {
     DoubleType: (float,),
     DecimalType: (decimal.Decimal,),
     StringType: (str,),
+    CharType: (str,),
     VarcharType: (str,),
     BinaryType: (bytearray, bytes),
     DateType: (datetime.date, datetime.datetime),
@@ -1697,8 +1733,8 @@ def _make_type_verifier(
                 new_msg("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
             )
 
-    if isinstance(dataType, (StringType, VarcharType)):
-        # StringType and VarcharType can work with any types
+    if isinstance(dataType, (StringType, CharType, VarcharType)):
+        # StringType, CharType and VarcharType can work with any types
         def verify_value(obj: Any) -> None:
             pass
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to