This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f532d222321 [SPARK-42954][PYTHON][CONNECT] Add `YearMonthIntervalType` to PySpark and Spark Connect Python Client
f532d222321 is described below

commit f532d222321aeec6d736ccf69f71b94fe07d4cd8
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Mar 29 17:39:54 2023 +0800

    [SPARK-42954][PYTHON][CONNECT] Add `YearMonthIntervalType` to PySpark and Spark Connect Python Client
    
    ### What changes were proposed in this pull request?
    Add `YearMonthIntervalType` to PySpark and Spark Connect Python Client
    
    ### Why are the changes needed?
    Function parity with the Scala API.
    
    **Note**
    The added `YearMonthIntervalType` is not yet supported in `collect`/`createDataFrame`, since there is no Python built-in type that corresponds to `YearMonthIntervalType` (the way `datetime.timedelta` corresponds to `DayTimeIntervalType`); this needs further discussion.
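
    For illustration, a minimal sketch of that asymmetry (assuming the existing `DayTimeIntervalType.fromInternal` helper; the microsecond value is just an example):
    ```
    import datetime
    from pyspark.sql.types import DayTimeIntervalType

    # DayTimeIntervalType values surface in Python as datetime.timedelta;
    # fromInternal converts from the internal microsecond representation.
    td = DayTimeIntervalType().fromInternal(90 * 60 * 1000 * 1000)  # 90 minutes
    assert td == datetime.timedelta(minutes=90)

    # YearMonthIntervalType gains no such conversion in this PR, so values of
    # this type cannot yet be materialized by collect()/createDataFrame().
    ```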
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, a new data type in Python.
    
    before this PR
    ```
    In [1]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval")
    Out[1]: ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    File ~/Dev/spark/python/pyspark/sql/dataframe.py:570, in DataFrame.schema(self)
        568 try:
        569     self._schema = cast(
    --> 570         StructType, _parse_datatype_json_string(self._jdf.schema().json())
        571     )
        572 except Exception as e:

    ...

    ValueError: Unable to parse datatype from schema. Could not parse datatype: interval year to month
    ```
    
    after this PR
    ```
    In [3]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval")
    Out[3]: DataFrame[interval: interval year to month]
    ```
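
    The new type can also be constructed directly; the behavior sketched below mirrors the unit tests added in this PR:
    ```
    from pyspark.sql.types import YearMonthIntervalType

    # No arguments: spans the full year-to-month range, matching the Scala default.
    YearMonthIntervalType().simpleString()      # 'interval year to month'

    # A lone start field implies endField == startField.
    YearMonthIntervalType(YearMonthIntervalType.YEAR).simpleString()   # 'interval year'
    YearMonthIntervalType(YearMonthIntervalType.MONTH).simpleString()  # 'interval month'

    # Out-of-range fields are rejected at construction time.
    YearMonthIntervalType(endField=3)  # RuntimeError: interval None to 3 is invalid
    ```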
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #40582 from zhengruifeng/py_y_m.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../source/reference/pyspark.sql/data_types.rst    |  1 +
 python/pyspark/sql/connect/types.py                | 16 +++++
 python/pyspark/sql/tests/test_types.py             | 32 ++++++++++
 python/pyspark/sql/types.py                        | 72 ++++++++++++++++++++--
 4 files changed, 116 insertions(+), 5 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/data_types.rst b/python/docs/source/reference/pyspark.sql/data_types.rst
index 53417e43419..60c6b92590d 100644
--- a/python/docs/source/reference/pyspark.sql/data_types.rst
+++ b/python/docs/source/reference/pyspark.sql/data_types.rst
@@ -47,3 +47,4 @@ Data Types
     TimestampType
     TimestampNTZType
     DayTimeIntervalType
+    YearMonthIntervalType
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index dfb0fb5303f..3afac8dc5b9 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -34,6 +34,7 @@ from pyspark.sql.types import (
     TimestampType,
     TimestampNTZType,
     DayTimeIntervalType,
+    YearMonthIntervalType,
     MapType,
     StringType,
     CharType,
@@ -154,6 +155,9 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
     elif isinstance(data_type, DayTimeIntervalType):
         ret.day_time_interval.start_field = data_type.startField
         ret.day_time_interval.end_field = data_type.endField
+    elif isinstance(data_type, YearMonthIntervalType):
+        ret.year_month_interval.start_field = data_type.startField
+        ret.year_month_interval.end_field = data_type.endField
     elif isinstance(data_type, StructType):
         for field in data_type.fields:
             struct_field = pb2.DataType.StructField()
@@ -236,6 +240,18 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
             else None
         )
         return DayTimeIntervalType(startField=start, endField=end)
+    elif schema.HasField("year_month_interval"):
+        start: Optional[int] = (  # type: ignore[no-redef]
+            schema.year_month_interval.start_field
+            if schema.year_month_interval.HasField("start_field")
+            else None
+        )
+        end: Optional[int] = (  # type: ignore[no-redef]
+            schema.year_month_interval.end_field
+            if schema.year_month_interval.HasField("end_field")
+            else None
+        )
+        return YearMonthIntervalType(startField=start, endField=end)
     elif schema.HasField("array"):
         return ArrayType(
             proto_schema_to_pyspark_data_type(schema.array.element_type),
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 5d6476b47f4..dd2abda4620 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -35,6 +35,7 @@ from pyspark.sql.types import (
     DateType,
     TimestampType,
     DayTimeIntervalType,
+    YearMonthIntervalType,
     MapType,
     StringType,
     CharType,
@@ -1190,6 +1191,37 @@ class TypesTestsMixin:
         for n, (a, e) in enumerate(zip(actual, expected)):
             self.assertEqual(a, e, "%s does not match with %s" % (exprs[n], expected[n]))
 
+    def test_yearmonth_interval_type_constructor(self):
+        self.assertEqual(YearMonthIntervalType().simpleString(), "interval year to month")
+        self.assertEqual(
+            YearMonthIntervalType(YearMonthIntervalType.YEAR).simpleString(), "interval year"
+        )
+        self.assertEqual(
+            YearMonthIntervalType(
+                YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH
+            ).simpleString(),
+            "interval year to month",
+        )
+
+        with self.assertRaisesRegex(RuntimeError, "interval None to 3 is invalid"):
+            YearMonthIntervalType(endField=3)
+
+        with self.assertRaisesRegex(RuntimeError, "interval 123 to 123 is invalid"):
+            YearMonthIntervalType(123)
+
+        with self.assertRaisesRegex(RuntimeError, "interval 0 to 321 is invalid"):
+            YearMonthIntervalType(YearMonthIntervalType.YEAR, 321)
+
+    def test_yearmonth_interval_type(self):
+        schema1 = self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval").schema
+        self.assertEqual(schema1.fields[0].dataType, YearMonthIntervalType(0, 1))
+
+        schema2 = self.spark.sql("SELECT INTERVAL '10' YEAR AS interval").schema
+        self.assertEqual(schema2.fields[0].dataType, YearMonthIntervalType(0, 0))
+
+        schema3 = self.spark.sql("SELECT INTERVAL '8' MONTH AS interval").schema
+        self.assertEqual(schema3.fields[0].dataType, YearMonthIntervalType(1, 1))
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index ff43e4b00e9..721be76e8ba 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -74,6 +74,7 @@ __all__ = [
     "IntegerType",
     "LongType",
     "DayTimeIntervalType",
+    "YearMonthIntervalType",
     "Row",
     "ShortType",
     "ArrayType",
@@ -374,7 +375,20 @@ class LongType(IntegralType):
         return "bigint"
 
 
-class DayTimeIntervalType(AtomicType):
+class ShortType(IntegralType):
+    """Short data type, i.e. a signed 16-bit integer."""
+
+    def simpleString(self) -> str:
+        return "smallint"
+
+
+class AnsiIntervalType(AtomicType):
+    """The interval type which conforms to the ANSI SQL standard."""
+
+    pass
+
+
+class DayTimeIntervalType(AnsiIntervalType):
     """DayTimeIntervalType (datetime.timedelta)."""
 
     DAY = 0
@@ -433,11 +447,48 @@ class DayTimeIntervalType(AtomicType):
             return datetime.timedelta(microseconds=micros)
 
 
-class ShortType(IntegralType):
-    """Short data type, i.e. a signed 16-bit integer."""
+class YearMonthIntervalType(AnsiIntervalType):
+    """YearMonthIntervalType, represents year-month intervals of the SQL 
standard"""
 
-    def simpleString(self) -> str:
-        return "smallint"
+    YEAR = 0
+    MONTH = 1
+
+    _fields = {
+        YEAR: "year",
+        MONTH: "month",
+    }
+
+    _inverted_fields = dict(zip(_fields.values(), _fields.keys()))
+
+    def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
+        if startField is None and endField is None:
+            # Default matched to scala side.
+            startField = YearMonthIntervalType.YEAR
+            endField = YearMonthIntervalType.MONTH
+        elif startField is not None and endField is None:
+            endField = startField
+
+        fields = YearMonthIntervalType._fields
+        if startField not in fields.keys() or endField not in fields.keys():
+            raise RuntimeError("interval %s to %s is invalid" % (startField, endField))
+        self.startField = cast(int, startField)
+        self.endField = cast(int, endField)
+
+    def _str_repr(self) -> str:
+        fields = YearMonthIntervalType._fields
+        start_field_name = fields[self.startField]
+        end_field_name = fields[self.endField]
+        if start_field_name == end_field_name:
+            return "interval %s" % start_field_name
+        else:
+            return "interval %s to %s" % (start_field_name, end_field_name)
+
+    simpleString = _str_repr
+
+    jsonValue = _str_repr
+
+    def __repr__(self) -> str:
+        return "%s(%d, %d)" % (type(self).__name__, self.startField, 
self.endField)
 
 
 class ArrayType(DataType):
@@ -1162,6 +1213,7 @@ _LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)")
 _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
+_INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
 
 
 def _parse_datatype_string(s: str) -> DataType:
@@ -1311,6 +1363,14 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
             if first_field is not None and second_field is None:
                 return DayTimeIntervalType(first_field)
             return DayTimeIntervalType(first_field, second_field)
+        elif _INTERVAL_YEARMONTH.match(json_value):
+            m = _INTERVAL_YEARMONTH.match(json_value)
+            inverted_fields = YearMonthIntervalType._inverted_fields
+            first_field = inverted_fields.get(m.group(1))  # type: ignore[union-attr]
+            second_field = inverted_fields.get(m.group(3))  # type: ignore[union-attr]
+            if first_field is not None and second_field is None:
+                return YearMonthIntervalType(first_field)
+            return YearMonthIntervalType(first_field, second_field)
         elif _LENGTH_CHAR.match(json_value):
             m = _LENGTH_CHAR.match(json_value)
             return CharType(int(m.group(1)))  # type: ignore[union-attr]
@@ -1465,6 +1525,8 @@ def _infer_type(
         return TimestampNTZType()
     if dataType is DayTimeIntervalType:
         return DayTimeIntervalType()
+    if dataType is YearMonthIntervalType:
+        return YearMonthIntervalType()
     elif dataType is not None:
         return dataType()
 
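For reference, a small sketch of the new schema-string parsing path. It calls `_parse_datatype_json_value`, a private helper, so this is an implementation detail rather than public API; the expected results follow from the hunk above:

```
from pyspark.sql.types import YearMonthIntervalType, _parse_datatype_json_value

# Strings matched by the new _INTERVAL_YEARMONTH regex now parse into
# YearMonthIntervalType instead of raising ValueError.
assert _parse_datatype_json_value("interval year to month") == YearMonthIntervalType(0, 1)
assert _parse_datatype_json_value("interval year") == YearMonthIntervalType(0, 0)
assert _parse_datatype_json_value("interval month") == YearMonthIntervalType(1, 1)
```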

