HyukjinKwon commented on code in PR #40276:
URL: https://github.com/apache/spark/pull/40276#discussion_r1125461718


##########
python/pyspark/sql/connect/types.py:
##########
@@ -342,20 +343,325 @@ def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
 
 
 def parse_data_type(data_type: str) -> DataType:
-    # Currently we don't have a way to have a current Spark session in Spark Connect, and
-    # pyspark.sql.SparkSession has a centralized logic to control the session creation.
-    # So uses pyspark.sql.SparkSession for now. Should replace this to using the current
-    # Spark session for Spark Connect in the future.
-    from pyspark.sql import SparkSession as PySparkSession
-
-    assert is_remote()
-    return_type_schema = (
        PySparkSession.builder.getOrCreate().createDataFrame(data=[], schema=data_type).schema
+    """
+    Parses the given data type string to a :class:`DataType`. The data type string format equals
+    :class:`DataType.simpleString`, except that the top level struct type can omit
+    the ``struct<>``. Since Spark 2.3, this also supports a schema in a DDL-formatted
+    string and case-insensitive strings.
+
+    Examples
+    --------
+    >>> parse_data_type("int ")
+    IntegerType()
+    >>> parse_data_type("INT ")
+    IntegerType()
+    >>> parse_data_type("a: byte, b: decimal(  16 , 8   ) ")
+    StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)])
+    >>> parse_data_type("a DOUBLE, b STRING")
+    StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)])
+    >>> parse_data_type("a DOUBLE, b CHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', CharType(50), True)])
+    >>> parse_data_type("a DOUBLE, b VARCHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)])
+    >>> parse_data_type("a: array< short>")
+    StructType([StructField('a', ArrayType(ShortType(), True), True)])
+    >>> parse_data_type(" map<string , string > ")
+    MapType(StringType(), StringType(), True)
+
+    >>> # Error cases
+    >>> parse_data_type("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    >>> parse_data_type("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    >>> parse_data_type("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    >>> parse_data_type("map<int, boolean>>") # doctest: 
+IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ParseException:...
+    """
+    try:
+        # DDL format, "fieldname datatype, fieldname datatype".
+        return DDLSchemaParser(data_type).from_ddl_schema()
+    except ParseException as e:
+        try:
+            # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
+            return DDLDataTypeParser(data_type).from_ddl_datatype()
+        except ParseException:
+            try:
+                # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
+                return DDLDataTypeParser(f"struct<{data_type}>").from_ddl_datatype()
+            except ParseException:
+                raise e from None
+
+
+class DataTypeParserBase:
+    REGEXP_IDENTIFIER: Final[Pattern] = re.compile("\\w+|`(?:``|[^`])*`", re.MULTILINE)
+    REGEXP_INTEGER_VALUES: Final[Pattern] = re.compile(
+        "\\(\\s*(?:[+-]?\\d+)\\s*(?:,\\s*(?:[+-]?\\d+)\\s*)*\\)", re.MULTILINE
     )
-    with_col_name = " " in data_type.strip()
-    if len(return_type_schema.fields) == 1 and not with_col_name:
-        # To match pyspark.sql.types._parse_datatype_string
-        return_type = return_type_schema.fields[0].dataType
-    else:
-        return_type = return_type_schema
-    return return_type
+    REGEXP_INTERVAL_TYPE: Final[Pattern] = re.compile(
+        "(day|hour|minute|second)(?:\\s+to\\s+(hour|minute|second))?", 
re.IGNORECASE | re.MULTILINE
+    )
+    REGEXP_NOT_NULL_COMMENT: Final[Pattern] = re.compile(
+        "(not\\s+null)?(?:(?(1)\\s+)comment\\s+'((?:\\\\'|[^'])*)')?", 
re.IGNORECASE | re.MULTILINE
+    )
+
+    def __init__(self, type_str: str):
+        self._type_str = type_str
+        self._pos = 0
+        self._lstrip()
+
+    def _lstrip(self) -> None:
+        remaining = self._type_str[self._pos :]
+        self._pos = self._pos + (len(remaining) - len(remaining.lstrip()))
+
+    def _parse_data_type(self) -> DataType:
+        type_str = self._type_str[self._pos :]
+        m = self.REGEXP_IDENTIFIER.match(type_str)
+        if m:
+            data_type_name = m.group(0).lower().strip("`").replace("``", "`")
+            self._pos = self._pos + len(m.group(0))
+            self._lstrip()
+            if data_type_name == "array":
+                return self._parse_array_type()
+            elif data_type_name == "map":
+                return self._parse_map_type()
+            elif data_type_name == "struct":
+                return self._parse_struct_type()
+            elif data_type_name == "interval":
+                return self._parse_interval_type()
+            else:
+                return self._parse_primitive_types(data_type_name)
+
+        raise ParseException(
+            error_class="PARSE_SYNTAX_ERROR",
+            message_parameters={"error": f"'{type_str}'", "hint": ""},
+        )
+
+    def _parse_array_type(self) -> ArrayType:
+        type_str = self._type_str[self._pos :]
+        if len(type_str) > 0 and type_str[0] == "<":
+            self._pos = self._pos + 1
+            self._lstrip()
+            element_type = self._parse_data_type()
+            remaining = self._type_str[self._pos :]
+            if len(remaining) and remaining[0] == ">":
+                self._pos = self._pos + 1
+                self._lstrip()
+                return ArrayType(element_type)
+        raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.ARRAY", 
message_parameters={})
+
+    def _parse_map_type(self) -> MapType:
+        type_str = self._type_str[self._pos :]
+        if len(type_str) > 0 and type_str[0] == "<":
+            self._pos = self._pos + 1
+            self._lstrip()
+            key_type = self._parse_data_type()
+            remaining = self._type_str[self._pos :]
+            if len(remaining) > 0 and remaining[0] == ",":
+                self._pos = self._pos + 1
+                self._lstrip()
+                value_type = self._parse_data_type()
+                remaining = self._type_str[self._pos :]
+                if len(remaining) > 0 and remaining[0] == ">":
+                    self._pos = self._pos + 1
+                    self._lstrip()
+                    return MapType(key_type, value_type)
+        raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.MAP", 
message_parameters={})
+
+    def _parse_struct_type(self) -> StructType:
+        type_str = self._type_str[self._pos :]
+        if len(type_str) > 0 and type_str[0] == "<":
+            self._pos = self._pos + 1
+            self._lstrip()
+            fields = self._parse_struct_fields()
+            remaining = self._type_str[self._pos :]
+            if len(remaining) > 0 and remaining[0] == ">":
+                self._pos = self._pos + 1
+                self._lstrip()
+                return StructType(fields)
+        raise ParseException(error_class="INCOMPLETE_TYPE_DEFINITION.STRUCT", 
message_parameters={})
+
+    def _parse_struct_fields(self, sep_with_colon: bool = True) -> List[StructField]:
+        type_str = self._type_str[self._pos :]
+        m = self.REGEXP_IDENTIFIER.match(type_str)
+        if m:
+            field_name = m.group(0).lower().strip("`").replace("``", "`")

Review Comment:
   A couple of concerns:
   
   I still doubt whether this is something we should implement manually. For example, what about case sensitivity? To do this properly, we should use ANTLR for Python (e.g., https://github.com/antlr/antlr4/blob/master/doc/python-target.md), but I feel that's sort of overkill.
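   
   Just for reference, a minimal sketch of what the ANTLR route could look like, assuming a hypothetical `DataType.g4` grammar compiled with `antlr4 -Dlanguage=Python3` (the grammar, the generated class names, and the `dataType` entry rule are all assumptions, not something in this PR):
   
   ```python
   from antlr4 import CommonTokenStream, InputStream
   
   # Hypothetical classes generated by `antlr4 -Dlanguage=Python3 DataType.g4`;
   # a grammar covering Spark's DDL type syntax would have to be written first.
   from DataTypeLexer import DataTypeLexer
   from DataTypeParser import DataTypeParser
   
   
   def parse_with_antlr(type_str: str):
       # Standard ANTLR pipeline: characters -> tokens -> parse tree.
       lexer = DataTypeLexer(InputStream(type_str))
       parser = DataTypeParser(CommonTokenStream(lexer))
       # `dataType` is the assumed grammar entry rule; a tree visitor would
       # then map the parse tree onto pyspark DataType instances.
       return parser.dataType()
   ```
   
   Case insensitivity would then be handled in the lexer rules rather than via scattered `re.IGNORECASE` flags, at the cost of a grammar file to maintain and a dependency on `antlr4-python3-runtime`.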
   
   We don't add new types very often, but we still do. For example, in the last couple of years we added the ANSI day-time and year-month interval types, timestamp without timezone, etc. We also exposed the char and varchar types that were hidden before.


