This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a46296e74408 [SPARK-52706][SQL] Fix inconsistencies and refactor 
primitive types in parser
a46296e74408 is described below

commit a46296e74408626d1e462ff0c68c27b3fa0b4d47
Author: Mihailo Milosevic <[email protected]>
AuthorDate: Wed Jul 9 08:37:22 2025 +0800

    [SPARK-52706][SQL] Fix inconsistencies and refactor primitive types in 
parser
    
    ### What changes were proposed in this pull request?
    This PR proposes a change in how our parser treats data types. We introduce 
types with and without parameters and group them accordingly.
    
    ### Why are the changes needed?
    Changes are needed for many reasons:
    1. The context of primitiveDataType is constantly growing. This is not good 
practice, as we end up with many null fields that only take up memory.
    2. We have inconsistencies in where each type is handled. TIMESTAMP_NTZ is 
parsed in a separate rule, but it is also mentioned among the primitive types.
    3. Primitive types should stay related to primitive types; adding ARRAY, 
STRUCT, and MAP to the rule just because it is convenient is not good practice.
    4. The current structure does not give us the option of extending types with 
different features. For example, we introduced STRING collations, but what if we 
were to introduce CHAR/VARCHAR with collations? The current structure makes it 
impossible to express a type as CHAR(5) COLLATE UTF8_BINARY (we can only do CHAR 
COLLATE UTF8_BINARY (5)).
    
    ### Does this PR introduce _any_ user-facing change?
    No. This is internal refactoring.
    
    ### How was this patch tested?
    All existing tests should pass, this is just code refactoring.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #51335 from mihailom-db/restructure-primitive.
    
    Authored-by: Mihailo Milosevic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |  46 ++++---
 .../sql/catalyst/parser/DataTypeAstBuilder.scala   | 144 ++++++++++++---------
 .../spark/sql/errors/QueryParsingErrors.scala      |   6 +-
 3 files changed, 110 insertions(+), 86 deletions(-)

diff --git 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index fc3d86ca858f..698afa486002 100644
--- 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -1340,7 +1340,20 @@ collateClause
     : COLLATE collationName=multipartIdentifier
     ;
 
-type
+nonTrivialPrimitiveType
+    : STRING collateClause?
+    | (CHARACTER | CHAR) (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
+    | VARCHAR (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
+    | (DECIMAL | DEC | NUMERIC)
+        (LEFT_PAREN precision=INTEGER_VALUE (COMMA scale=INTEGER_VALUE)? 
RIGHT_PAREN)?
+    | INTERVAL
+        (fromYearMonth=(YEAR | MONTH) (TO to=MONTH)? |
+         fromDayTime=(DAY | HOUR | MINUTE | SECOND) (TO to=(HOUR | MINUTE | 
SECOND))?)?
+    | TIMESTAMP (WITHOUT TIME ZONE)?
+    | TIME (LEFT_PAREN precision=INTEGER_VALUE RIGHT_PAREN)? (WITHOUT TIME 
ZONE)?
+    ;
+
+trivialPrimitiveType
     : BOOLEAN
     | TINYINT | BYTE
     | SMALLINT | SHORT
@@ -1349,32 +1362,23 @@ type
     | FLOAT | REAL
     | DOUBLE
     | DATE
-    | TIME
-    | TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ
-    | STRING collateClause?
-    | CHARACTER | CHAR
-    | VARCHAR
+    | TIMESTAMP_LTZ | TIMESTAMP_NTZ
     | BINARY
-    | DECIMAL | DEC | NUMERIC
     | VOID
-    | INTERVAL
     | VARIANT
-    | ARRAY | STRUCT | MAP
-    | unsupportedType=identifier
+    ;
+
+primitiveType
+    : nonTrivialPrimitiveType
+    | trivialPrimitiveType
+    | unsupportedType=identifier (LEFT_PAREN INTEGER_VALUE(COMMA 
INTEGER_VALUE)* RIGHT_PAREN)?
     ;
 
 dataType
-    : complex=ARRAY LT dataType GT                              
#complexDataType
-    | complex=MAP LT dataType COMMA dataType GT                 
#complexDataType
-    | complex=STRUCT (LT complexColTypeList? GT | NEQ)          
#complexDataType
-    | INTERVAL from=(YEAR | MONTH) (TO to=MONTH)?               
#yearMonthIntervalDataType
-    | INTERVAL from=(DAY | HOUR | MINUTE | SECOND)
-      (TO to=(HOUR | MINUTE | SECOND))?                         
#dayTimeIntervalDataType
-    | TIME (LEFT_PAREN precision=INTEGER_VALUE RIGHT_PAREN)?
-      (WITHOUT TIME ZONE)?                                      #timeDataType
-    | (TIMESTAMP_NTZ | TIMESTAMP WITHOUT TIME ZONE)             
#timestampNtzDataType
-    | type (LEFT_PAREN INTEGER_VALUE
-      (COMMA INTEGER_VALUE)* RIGHT_PAREN)?                      
#primitiveDataType
+    : complex=ARRAY (LT dataType GT)?                           
#complexDataType
+    | complex=MAP (LT dataType COMMA dataType GT)?              
#complexDataType
+    | complex=STRUCT ((LT complexColTypeList? GT) | NEQ)?       
#complexDataType
+    | primitiveType                                             
#primitiveDataType
     ;
 
 qualifiedColTypeWithPositionList
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
index e83a987263db..beb7061a841a 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
@@ -65,74 +65,89 @@ class DataTypeAstBuilder extends 
SqlBaseParserBaseVisitor[AnyRef] {
       ctx.parts.asScala.map(_.getText).toSeq
     }
 
-  /**
-   * Resolve/create the TIME primitive type.
-   */
-  override def visitTimeDataType(ctx: TimeDataTypeContext): DataType = 
withOrigin(ctx) {
-    val precision = if (ctx.precision == null) {
-      TimeType.DEFAULT_PRECISION
-    } else {
-      ctx.precision.getText.toInt
-    }
-    TimeType(precision)
-  }
-
-  /**
-   * Create the TIMESTAMP_NTZ primitive type.
-   */
-  override def visitTimestampNtzDataType(ctx: TimestampNtzDataTypeContext): 
DataType = {
-    withOrigin(ctx)(TimestampNTZType)
-  }
-
   /**
    * Resolve/create a primitive type.
    */
   override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType 
= withOrigin(ctx) {
-    val typeCtx = ctx.`type`
-    (typeCtx.start.getType, ctx.INTEGER_VALUE().asScala.toList) match {
-      case (BOOLEAN, Nil) => BooleanType
-      case (TINYINT | BYTE, Nil) => ByteType
-      case (SMALLINT | SHORT, Nil) => ShortType
-      case (INT | INTEGER, Nil) => IntegerType
-      case (BIGINT | LONG, Nil) => LongType
-      case (FLOAT | REAL, Nil) => FloatType
-      case (DOUBLE, Nil) => DoubleType
-      case (DATE, Nil) => DateType
-      case (TIMESTAMP, Nil) => SqlApiConf.get.timestampType
-      case (TIMESTAMP_LTZ, Nil) => TimestampType
-      case (STRING, Nil) =>
-        typeCtx.children.asScala.toSeq match {
-          case Seq(_) => StringType
-          case Seq(_, ctx: CollateClauseContext) =>
-            val collationNameParts = visitCollateClause(ctx).toArray
-            val collationId = CollationFactory.collationNameToId(
-              CollationFactory.resolveFullyQualifiedName(collationNameParts))
-            StringType(collationId)
-        }
-      case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt)
-      case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt)
-      case (BINARY, Nil) => BinaryType
-      case (DECIMAL | DEC | NUMERIC, Nil) => DecimalType.USER_DEFAULT
-      case (DECIMAL | DEC | NUMERIC, precision :: Nil) =>
-        DecimalType(precision.getText.toInt, 0)
-      case (DECIMAL | DEC | NUMERIC, precision :: scale :: Nil) =>
-        DecimalType(precision.getText.toInt, scale.getText.toInt)
-      case (VOID, Nil) => NullType
-      case (INTERVAL, Nil) => CalendarIntervalType
-      case (VARIANT, Nil) => VariantType
-      case (CHARACTER | CHAR | VARCHAR, Nil) =>
-        throw 
QueryParsingErrors.charTypeMissingLengthError(ctx.`type`.getText, ctx)
-      case (ARRAY | STRUCT | MAP, Nil) =>
-        throw 
QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.`type`.getText, ctx)
-      case (_, params) =>
-        val badType = ctx.`type`.getText
-        val dtStr = if (params.nonEmpty) s"$badType(${params.mkString(",")})" 
else badType
-        throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx)
+    val typeCtx = ctx.primitiveType
+    if (typeCtx.nonTrivialPrimitiveType != null) {
+      // This is a primitive type with parameters, e.g. VARCHAR(10), 
DECIMAL(10, 2), etc.
+      val currentCtx = typeCtx.nonTrivialPrimitiveType
+      currentCtx.start.getType match {
+        case STRING =>
+          currentCtx.children.asScala.toSeq match {
+            case Seq(_) => StringType
+            case Seq(_, ctx: CollateClauseContext) =>
+              val collationNameParts = visitCollateClause(ctx).toArray
+              val collationId = CollationFactory.collationNameToId(
+                CollationFactory.resolveFullyQualifiedName(collationNameParts))
+              StringType(collationId)
+          }
+        case CHARACTER | CHAR =>
+          if (currentCtx.length == null) {
+            throw 
QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
+          } else CharType(currentCtx.length.getText.toInt)
+        case VARCHAR =>
+          if (currentCtx.length == null) {
+            throw 
QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
+          } else VarcharType(currentCtx.length.getText.toInt)
+        case DECIMAL | DEC | NUMERIC =>
+          if (currentCtx.precision == null) {
+            DecimalType.USER_DEFAULT
+          } else if (currentCtx.scale == null) {
+            DecimalType(currentCtx.precision.getText.toInt, 0)
+          } else {
+            DecimalType(currentCtx.precision.getText.toInt, 
currentCtx.scale.getText.toInt)
+          }
+        case INTERVAL =>
+          if (currentCtx.fromDayTime != null) {
+            visitDayTimeIntervalDataType(currentCtx)
+          } else if (currentCtx.fromYearMonth != null) {
+            visitYearMonthIntervalDataType(currentCtx)
+          } else {
+            CalendarIntervalType
+          }
+        case TIMESTAMP =>
+          if (currentCtx.WITHOUT() == null) {
+            SqlApiConf.get.timestampType
+          } else TimestampNTZType
+        case TIME =>
+          val precision = if (currentCtx.precision == null) {
+            TimeType.DEFAULT_PRECISION
+          } else {
+            currentCtx.precision.getText.toInt
+          }
+          TimeType(precision)
+      }
+    } else if (typeCtx.trivialPrimitiveType != null) {
+      // This is a primitive type without parameters, e.g. BOOLEAN, TINYINT, 
etc.
+      typeCtx.trivialPrimitiveType.start.getType match {
+        case BOOLEAN => BooleanType
+        case TINYINT | BYTE => ByteType
+        case SMALLINT | SHORT => ShortType
+        case INT | INTEGER => IntegerType
+        case BIGINT | LONG => LongType
+        case FLOAT | REAL => FloatType
+        case DOUBLE => DoubleType
+        case DATE => DateType
+        case TIMESTAMP_LTZ => TimestampType
+        case TIMESTAMP_NTZ => TimestampNTZType
+        case BINARY => BinaryType
+        case VOID => NullType
+        case VARIANT => VariantType
+      }
+    } else {
+      val badType = typeCtx.unsupportedType.getText
+      val params = typeCtx.INTEGER_VALUE().asScala.toList
+      val dtStr =
+        if (params.nonEmpty) s"$badType(${params.mkString(",")})"
+        else badType
+      throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx)
     }
   }
 
-  override def visitYearMonthIntervalDataType(ctx: 
YearMonthIntervalDataTypeContext): DataType = {
-    val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
+  private def visitYearMonthIntervalDataType(ctx: 
NonTrivialPrimitiveTypeContext): DataType = {
+    val startStr = ctx.fromYearMonth.getText.toLowerCase(Locale.ROOT)
     val start = YearMonthIntervalType.stringToField(startStr)
     if (ctx.to != null) {
       val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
@@ -146,8 +161,8 @@ class DataTypeAstBuilder extends 
SqlBaseParserBaseVisitor[AnyRef] {
     }
   }
 
-  override def visitDayTimeIntervalDataType(ctx: 
DayTimeIntervalDataTypeContext): DataType = {
-    val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
+  private def visitDayTimeIntervalDataType(ctx: 
NonTrivialPrimitiveTypeContext): DataType = {
+    val startStr = ctx.fromDayTime.getText.toLowerCase(Locale.ROOT)
     val start = DayTimeIntervalType.stringToField(startStr)
     if (ctx.to != null) {
       val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
@@ -165,6 +180,9 @@ class DataTypeAstBuilder extends 
SqlBaseParserBaseVisitor[AnyRef] {
    * Create a complex DataType. Arrays, Maps and Structures are supported.
    */
   override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = 
withOrigin(ctx) {
+    if (ctx.LT() == null && ctx.NEQ() == null) {
+      throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.getText, 
ctx)
+    }
     ctx.complex.getType match {
       case SqlBaseParser.ARRAY =>
         ArrayType(typedVisit(ctx.dataType(0)))
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index 12f986b89fd2..60ccf7a9282c 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -324,7 +324,9 @@ private[sql] object QueryParsingErrors extends 
DataTypeErrorsBase {
       ctx)
   }
 
-  def charTypeMissingLengthError(dataType: String, ctx: 
PrimitiveDataTypeContext): Throwable = {
+  def charVarcharTypeMissingLengthError(
+      dataType: String,
+      ctx: PrimitiveDataTypeContext): Throwable = {
     new ParseException(
       errorClass = "DATATYPE_MISSING_SIZE",
       messageParameters = Map("type" -> toSQLType(dataType)),
@@ -333,7 +335,7 @@ private[sql] object QueryParsingErrors extends 
DataTypeErrorsBase {
 
   def nestedTypeMissingElementTypeError(
       dataType: String,
-      ctx: PrimitiveDataTypeContext): Throwable = {
+      ctx: ComplexDataTypeContext): Throwable = {
     dataType.toUpperCase(Locale.ROOT) match {
       case "ARRAY" =>
         new ParseException(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to