HyukjinKwon commented on code in PR #40276: URL: https://github.com/apache/spark/pull/40276#discussion_r1125461718
########## python/pyspark/sql/connect/types.py: ########## @@ -342,20 +343,325 @@ def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType: def parse_data_type(data_type: str) -> DataType: - # Currently we don't have a way to have a current Spark session in Spark Connect, and - # pyspark.sql.SparkSession has a centralized logic to control the session creation. - # So uses pyspark.sql.SparkSession for now. Should replace this to using the current - # Spark session for Spark Connect in the future. - from pyspark.sql import SparkSession as PySparkSession - - assert is_remote() - return_type_schema = ( - PySparkSession.builder.getOrCreate().createDataFrame(data=[], schema=data_type).schema + """ + Parses the given data type string to a :class:`DataType`. The data type string format equals + :class:`DataType.simpleString`, except that the top level struct type can omit + the ``struct<>``. Since Spark 2.3, this also supports a schema in a DDL-formatted + string and case-insensitive strings. 
+ + Examples + -------- + >>> parse_data_type("int ") + IntegerType() + >>> parse_data_type("INT ") + IntegerType() + >>> parse_data_type("a: byte, b: decimal( 16 , 8 ) ") + StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)]) + >>> parse_data_type("a DOUBLE, b STRING") + StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)]) + >>> parse_data_type("a DOUBLE, b CHAR( 50 )") + StructType([StructField('a', DoubleType(), True), StructField('b', CharType(50), True)]) + >>> parse_data_type("a DOUBLE, b VARCHAR( 50 )") + StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)]) + >>> parse_data_type("a: array< short>") + StructType([StructField('a', ArrayType(ShortType(), True), True)]) + >>> parse_data_type(" map<string , string > ") + MapType(StringType(), StringType(), True) + + >>> # Error cases + >>> parse_data_type("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ParseException:... + >>> parse_data_type("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ParseException:... + >>> parse_data_type("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ParseException:... + >>> parse_data_type("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ParseException:... + """ + try: + # DDL format, "fieldname datatype, fieldname datatype". + return DDLSchemaParser(data_type).from_ddl_schema() + except ParseException as e: + try: + # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc. + return DDLDataTypeParser(data_type).from_ddl_datatype() + except ParseException: + try: + # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case. 
+ return DDLDataTypeParser(f"struct<{data_type}>").from_ddl_datatype() + except ParseException: + raise e from None + + +class DataTypeParserBase: + REGEXP_IDENTIFIER: Final[Pattern] = re.compile("\\w+|`(?:``|[^`])*`", re.MULTILINE) + REGEXP_INTEGER_VALUES: Final[Pattern] = re.compile( + "\\(\\s*(?:[+-]?\\d+)\\s*(?:,\\s*(?:[+-]?\\d+)\\s*)*\\)", re.MULTILINE ) - with_col_name = " " in data_type.strip() - if len(return_type_schema.fields) == 1 and not with_col_name: - # To match pyspark.sql.types._parse_datatype_string - return_type = return_type_schema.fields[0].dataType - else: - return_type = return_type_schema - return return_type + REGEXP_INTERVAL_TYPE: Final[Pattern] = re.compile( + "(day|hour|minute|second)(?:\\s+to\\s+(hour|minute|second))?", re.IGNORECASE | re.MULTILINE + ) + REGEXP_NOT_NULL_COMMENT: Final[Pattern] = re.compile( + "(not\\s+null)?(?:(?(1)\\s+)comment\\s+'((?:\\\\'|[^'])*)')?", re.IGNORECASE | re.MULTILINE + ) + + def __init__(self, type_str: str): + self._type_str = type_str + self._pos = 0 + self._lstrip() + + def _lstrip(self) -> None: + remaining = self._type_str[self._pos :] + self._pos = self._pos + (len(remaining) - len(remaining.lstrip())) + + def _parse_data_type(self) -> DataType: + type_str = self._type_str[self._pos :] + m = self.REGEXP_IDENTIFIER.match(type_str) + if m: + data_type_name = m.group(0).lower().strip("`").replace("``", "`") + self._pos = self._pos + len(m.group(0)) + self._lstrip() + if data_type_name == "array": + return self._parse_array_type() + elif data_type_name == "map": + return self._parse_map_type() + elif data_type_name == "struct": + return self._parse_struct_type() + elif data_type_name == "interval": + return self._parse_interval_type() + else: + return self._parse_primitive_types(data_type_name) + + raise ParseException( + error_class="PARSE_SYNTAX_ERROR", + message_parameters={"error": f"'{type_str}'", "hint": ""}, + ) + + def _parse_array_type(self) -> ArrayType: + type_str = self._type_str[self._pos 
:] + if len(type_str) > 0 and type_str[0] == "<": + self._pos = self._pos + 1 + self._lstrip() + element_type = self._parse_data_type() + remaining = self._type_str[self._pos :] + if len(remaining) and remaining[0] == ">": + self._pos = self._pos + 1 + self._lstrip() + return ArrayType(element_type) + raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.ARRAY", message_parameters={}) + + def _parse_map_type(self) -> MapType: + type_str = self._type_str[self._pos :] + if len(type_str) > 0 and type_str[0] == "<": + self._pos = self._pos + 1 + self._lstrip() + key_type = self._parse_data_type() + remaining = self._type_str[self._pos :] + if len(remaining) > 0 and remaining[0] == ",": + self._pos = self._pos + 1 + self._lstrip() + value_type = self._parse_data_type() + remaining = self._type_str[self._pos :] + if len(remaining) > 0 and remaining[0] == ">": + self._pos = self._pos + 1 + self._lstrip() + return MapType(key_type, value_type) + raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.MAP", message_parameters={}) + + def _parse_struct_type(self) -> StructType: + type_str = self._type_str[self._pos :] + if len(type_str) > 0 and type_str[0] == "<": + self._pos = self._pos + 1 + self._lstrip() + fields = self._parse_struct_fields() + remaining = self._type_str[self._pos :] + if len(remaining) > 0 and remaining[0] == ">": + self._pos = self._pos + 1 + self._lstrip() + return StructType(fields) + raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.STRUCT", message_parameters={}) + + def _parse_struct_fields(self, sep_with_colon: bool = True) -> List[StructField]: + type_str = self._type_str[self._pos :] + m = self.REGEXP_IDENTIFIER.match(type_str) + if m: + field_name = m.group(0).lower().strip("`").replace("``", "`") Review Comment: Couple of concerns.. I still doubt if this is something we should manually implement ... for example, what about case sensitivity .. 
To do this properly, we should use antlr for Python (e.g., https://github.com/antlr/antlr4/blob/master/doc/python-target.md). But I feel that's sort of overkill. We don't add new types very often, but we still do add them. For example, in the last couple of years, we added ANSI types, date, year, month interval, timestamp without timezone, etc. We also exposed the char and varchar types that were hidden before. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org