This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f532d222321  [SPARK-42954][PYTHON][CONNECT] Add `YearMonthIntervalType` to PySpark and Spark Connect Python Client
f532d222321 is described below

commit f532d222321aeec6d736ccf69f71b94fe07d4cd8
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Mar 29 17:39:54 2023 +0800

[SPARK-42954][PYTHON][CONNECT] Add `YearMonthIntervalType` to PySpark and Spark Connect Python Client

### What changes were proposed in this pull request?

Add `YearMonthIntervalType` to PySpark and the Spark Connect Python Client.

### Why are the changes needed?

Function parity.

**Note:** the added `YearMonthIntervalType` is not yet supported in `collect`/`createDataFrame`, since there is no Python built-in type for `YearMonthIntervalType` (unlike `datetime.timedelta` for `DayTimeIntervalType`); that part needs further discussion.

### Does this PR introduce _any_ user-facing change?

Yes, a new data type in Python.

Before this PR:
```
In [1]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval")
Out[1]:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/Dev/spark/python/pyspark/sql/dataframe.py:570, in DataFrame.schema(self)
    568 try:
    569     self._schema = cast(
--> 570         StructType, _parse_datatype_json_string(self._jdf.schema().json())
    571     )
    572 except Exception as e:
...
ValueError: Unable to parse datatype from schema. Could not parse datatype: interval year to month
```

After this PR:
```
In [3]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval")
Out[3]: DataFrame[interval: interval year to month]
```

### How was this patch tested?

Added unit tests.

Closes #40582 from zhengruifeng/py_y_m.

Authored-by: Ruifeng Zheng <ruife...@apache.org>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
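To make the note above concrete: `DayTimeIntervalType` already has a Python carrier type, which is exactly what `YearMonthIntervalType` lacks. A minimal sketch of the asymmetry (not part of the patch; assumes a local Spark session with this change applied):
```
import datetime

from pyspark.sql import Row, SparkSession
from pyspark.sql.types import DayTimeIntervalType, YearMonthIntervalType

spark = SparkSession.builder.getOrCreate()

# datetime.timedelta gives day-time intervals a built-in Python carrier,
# so both schema inference and collect() work for them:
df = spark.createDataFrame([Row(d=datetime.timedelta(days=1, hours=2))])
assert isinstance(df.schema.fields[0].dataType, DayTimeIntervalType)
assert df.collect()[0].d == datetime.timedelta(days=1, hours=2)

# "10 years 8 months" has no stdlib counterpart, so with this patch the
# year-month type can appear in schemas, but its values still cannot be
# converted to Python objects via collect()/createDataFrame().
print(YearMonthIntervalType().simpleString())  # interval year to month
```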
---
 .../source/reference/pyspark.sql/data_types.rst |  1 +
 python/pyspark/sql/connect/types.py             | 16 +++++
 python/pyspark/sql/tests/test_types.py          | 32 ++++++++++
 python/pyspark/sql/types.py                     | 72 ++++++++++++++++++++--
 4 files changed, 116 insertions(+), 5 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/data_types.rst b/python/docs/source/reference/pyspark.sql/data_types.rst
index 53417e43419..60c6b92590d 100644
--- a/python/docs/source/reference/pyspark.sql/data_types.rst
+++ b/python/docs/source/reference/pyspark.sql/data_types.rst
@@ -47,3 +47,4 @@ Data Types
     TimestampType
     TimestampNTZType
     DayTimeIntervalType
+    YearMonthIntervalType

diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index dfb0fb5303f..3afac8dc5b9 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -34,6 +34,7 @@ from pyspark.sql.types import (
     TimestampType,
     TimestampNTZType,
     DayTimeIntervalType,
+    YearMonthIntervalType,
     MapType,
     StringType,
     CharType,
@@ -154,6 +155,9 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
     elif isinstance(data_type, DayTimeIntervalType):
         ret.day_time_interval.start_field = data_type.startField
         ret.day_time_interval.end_field = data_type.endField
+    elif isinstance(data_type, YearMonthIntervalType):
+        ret.year_month_interval.start_field = data_type.startField
+        ret.year_month_interval.end_field = data_type.endField
     elif isinstance(data_type, StructType):
         for field in data_type.fields:
             struct_field = pb2.DataType.StructField()
@@ -236,6 +240,18 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
             else None
         )
         return DayTimeIntervalType(startField=start, endField=end)
+    elif schema.HasField("year_month_interval"):
+        start: Optional[int] = (  # type: ignore[no-redef]
+            schema.year_month_interval.start_field
+            if schema.year_month_interval.HasField("start_field")
+            else None
+        )
+        end: Optional[int] = (  # type: ignore[no-redef]
+            schema.year_month_interval.end_field
+            if schema.year_month_interval.HasField("end_field")
+            else None
+        )
+        return YearMonthIntervalType(startField=start, endField=end)
     elif schema.HasField("array"):
         return ArrayType(
             proto_schema_to_pyspark_data_type(schema.array.element_type),
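The two converters patched above are symmetric, so the new branch can be sanity-checked with a round trip. A sketch assuming this patch is applied and the Spark Connect protos are available (all names come from the diff itself):
```
from pyspark.sql.types import YearMonthIntervalType
from pyspark.sql.connect.types import (
    proto_schema_to_pyspark_data_type,
    pyspark_types_to_proto_types,
)

# PySpark type -> protobuf: startField/endField land in the new
# year_month_interval fields of the DataType message.
proto = pyspark_types_to_proto_types(YearMonthIntervalType(0, 1))
assert proto.year_month_interval.start_field == YearMonthIntervalType.YEAR
assert proto.year_month_interval.end_field == YearMonthIntervalType.MONTH

# protobuf -> PySpark type: the HasField() guards map absent fields to
# None, and the constructor defaults then kick in.
assert proto_schema_to_pyspark_data_type(proto) == YearMonthIntervalType(0, 1)
```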
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 5d6476b47f4..dd2abda4620 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -35,6 +35,7 @@ from pyspark.sql.types import (
     DateType,
     TimestampType,
     DayTimeIntervalType,
+    YearMonthIntervalType,
     MapType,
     StringType,
     CharType,
@@ -1190,6 +1191,37 @@ class TypesTestsMixin:
         for n, (a, e) in enumerate(zip(actual, expected)):
             self.assertEqual(a, e, "%s does not match with %s" % (exprs[n], expected[n]))
 
+    def test_yearmonth_interval_type_constructor(self):
+        self.assertEqual(YearMonthIntervalType().simpleString(), "interval year to month")
+        self.assertEqual(
+            YearMonthIntervalType(YearMonthIntervalType.YEAR).simpleString(), "interval year"
+        )
+        self.assertEqual(
+            YearMonthIntervalType(
+                YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH
+            ).simpleString(),
+            "interval year to month",
+        )
+
+        with self.assertRaisesRegex(RuntimeError, "interval None to 3 is invalid"):
+            YearMonthIntervalType(endField=3)
+
+        with self.assertRaisesRegex(RuntimeError, "interval 123 to 123 is invalid"):
+            YearMonthIntervalType(123)
+
+        with self.assertRaisesRegex(RuntimeError, "interval 0 to 321 is invalid"):
+            YearMonthIntervalType(YearMonthIntervalType.YEAR, 321)
+
+    def test_yearmonth_interval_type(self):
+        schema1 = self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval").schema
+        self.assertEqual(schema1.fields[0].dataType, YearMonthIntervalType(0, 1))
+
+        schema2 = self.spark.sql("SELECT INTERVAL '10' YEAR AS interval").schema
+        self.assertEqual(schema2.fields[0].dataType, YearMonthIntervalType(0, 0))
+
+        schema3 = self.spark.sql("SELECT INTERVAL '8' MONTH AS interval").schema
+        self.assertEqual(schema3.fields[0].dataType, YearMonthIntervalType(1, 1))
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
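These tests pin down the constructor semantics: no arguments defaults to year-to-month (matching the Scala side), a single `startField` collapses the interval to that one field, and anything outside YEAR/MONTH raises. A short illustration, not part of the patch:
```
from pyspark.sql.types import YearMonthIntervalType

YearMonthIntervalType().simpleString()        # 'interval year to month'
YearMonthIntervalType(                        # endField defaults to startField,
    YearMonthIntervalType.MONTH               # so this is 'interval month'
).simpleString()

# Out-of-range fields fail fast with the message asserted above:
try:
    YearMonthIntervalType(endField=3)
except RuntimeError as e:
    print(e)  # interval None to 3 is invalid
```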
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index ff43e4b00e9..721be76e8ba 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -74,6 +74,7 @@ __all__ = [
     "IntegerType",
     "LongType",
     "DayTimeIntervalType",
+    "YearMonthIntervalType",
     "Row",
     "ShortType",
     "ArrayType",
@@ -374,7 +375,20 @@ class LongType(IntegralType):
         return "bigint"
 
 
-class DayTimeIntervalType(AtomicType):
+class ShortType(IntegralType):
+    """Short data type, i.e. a signed 16-bit integer."""
+
+    def simpleString(self) -> str:
+        return "smallint"
+
+
+class AnsiIntervalType(AtomicType):
+    """The interval type which conforms to the ANSI SQL standard."""
+
+    pass
+
+
+class DayTimeIntervalType(AnsiIntervalType):
     """DayTimeIntervalType (datetime.timedelta)."""
 
     DAY = 0
@@ -433,11 +447,48 @@ class DayTimeIntervalType(AtomicType):
         return datetime.timedelta(microseconds=micros)
 
 
-class ShortType(IntegralType):
-    """Short data type, i.e. a signed 16-bit integer."""
+class YearMonthIntervalType(AnsiIntervalType):
+    """YearMonthIntervalType, represents year-month intervals of the SQL standard"""
 
-    def simpleString(self) -> str:
-        return "smallint"
+    YEAR = 0
+    MONTH = 1
+
+    _fields = {
+        YEAR: "year",
+        MONTH: "month",
+    }
+
+    _inverted_fields = dict(zip(_fields.values(), _fields.keys()))
+
+    def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
+        if startField is None and endField is None:
+            # Default matched to scala side.
+            startField = YearMonthIntervalType.YEAR
+            endField = YearMonthIntervalType.MONTH
+        elif startField is not None and endField is None:
+            endField = startField
+
+        fields = YearMonthIntervalType._fields
+        if startField not in fields.keys() or endField not in fields.keys():
+            raise RuntimeError("interval %s to %s is invalid" % (startField, endField))
+        self.startField = cast(int, startField)
+        self.endField = cast(int, endField)
+
+    def _str_repr(self) -> str:
+        fields = YearMonthIntervalType._fields
+        start_field_name = fields[self.startField]
+        end_field_name = fields[self.endField]
+        if start_field_name == end_field_name:
+            return "interval %s" % start_field_name
+        else:
+            return "interval %s to %s" % (start_field_name, end_field_name)
+
+    simpleString = _str_repr
+
+    jsonValue = _str_repr
+
+    def __repr__(self) -> str:
+        return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField)
 
 
 class ArrayType(DataType):
@@ -1162,6 +1213,7 @@ _LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)")
 _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
+_INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
 
 
 def _parse_datatype_string(s: str) -> DataType:
@@ -1311,6 +1363,14 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
         if first_field is not None and second_field is None:
             return DayTimeIntervalType(first_field)
         return DayTimeIntervalType(first_field, second_field)
+    elif _INTERVAL_YEARMONTH.match(json_value):
+        m = _INTERVAL_YEARMONTH.match(json_value)
+        inverted_fields = YearMonthIntervalType._inverted_fields
+        first_field = inverted_fields.get(m.group(1))  # type: ignore[union-attr]
+        second_field = inverted_fields.get(m.group(3))  # type: ignore[union-attr]
+        if first_field is not None and second_field is None:
+            return YearMonthIntervalType(first_field)
+        return YearMonthIntervalType(first_field, second_field)
     elif _LENGTH_CHAR.match(json_value):
         m = _LENGTH_CHAR.match(json_value)
         return CharType(int(m.group(1)))  # type: ignore[union-attr]
@@ -1465,6 +1525,8 @@ def _infer_type(
         return TimestampNTZType()
     if dataType is DayTimeIntervalType:
         return DayTimeIntervalType()
+    if dataType is YearMonthIntervalType:
+        return YearMonthIntervalType()
     elif dataType is not None:
         return dataType()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org