This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a0192546617 [SPARK-53354][CONNECT] Simplify LiteralValueProtoConverter.toCatalystStruct
2a0192546617 is described below

commit 2a019254661723428ab80f5c7609021f8f8cc616
Author: Yihong He <heyihong...@gmail.com>
AuthorDate: Sun Aug 24 10:04:48 2025 +0800

    [SPARK-53354][CONNECT] Simplify LiteralValueProtoConverter.toCatalystStruct
    
    ### What changes were proposed in this pull request?
    
    This PR simplifies the `LiteralValueProtoConverter.toCatalystStruct` method
    by refactoring the struct conversion logic to be more straightforward and
    maintainable. The main changes include:
    
    1. **Simplified return type**: Changed `toCatalystStruct` to return only
       the converted struct value instead of a tuple `(Any, proto.DataType.Struct)`
    2. **Extracted struct type resolution**: Created a new `getProtoStructType`
       method to handle struct type resolution separately
    3. **Simplified internal conversion**: Introduced a `toCatalystStructInternal`
       method that takes the struct type as a parameter
    4. **Removed complex type inference logic**: Eliminated the
       `LiteralValueWithDataType` case class and simplified the `getConverter`
       method by removing the `inferDataType` parameter
    5. **Enhanced recursive type inference**: Improved the `getInferredDataType`
       method to support recursive type inference for struct fields (see the
       API sketch below)
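    
    A condensed sketch of the reshaped API, reconstructed from the diff
    below (signatures only, with bodies elided, so this is illustrative
    rather than a complete implementation):
    
    ```scala
    object LiteralValueProtoConverter {
      // toCatalystStruct now returns only the converted value (a scala.TupleN)
      // instead of a (value, proto.DataType.Struct) tuple.
      def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any =
        toCatalystStructInternal(struct, getProtoStructType(struct))
    
      // Struct type resolution is a separate entry point: it reads the type
      // from the literal when present, or infers it recursively from the fields.
      def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = ???
    
      // The internal conversion takes the resolved struct type as a parameter.
      private def toCatalystStructInternal(
          struct: proto.Expression.Literal.Struct,
          structType: proto.DataType.Struct): Any = ???
    }
    ```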
    
    ### Why are the changes needed?
    
    The original method was overly complex, with multiple code paths and
    conditional logic that made it difficult to understand and maintain. This
    refactoring improves code readability while preserving the same
    functionality.
    
    ### Does this PR introduce _any_ user-facing change?
    
    **No**. This is a pure refactoring that maintains the same external
    behavior and API.
    
    ### How was this patch tested?
    
    `build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`
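    
    For illustration, the updated suite now obtains the converted value and
    its proto struct type through two separate calls (condensed from the test
    changes in the diff below):
    
    ```scala
    // The tuple-returning toCatalystStruct is gone; type resolution is explicit.
    val result = LiteralValueProtoConverter.toCatalystStruct(structProto)
    val resultType = LiteralValueProtoConverter.getProtoStructType(structProto)
    assert(result.isInstanceOf[Product])
    ```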
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Cursor 1.4.5
    
    Closes #52098 from heyihong/SPARK-53354.
    
    Authored-by: Yihong He <heyihong...@gmail.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../common/LiteralValueProtoConverter.scala        | 145 ++++++++-------------
 .../planner/LiteralExpressionProtoConverter.scala  |   6 +-
 .../LiteralExpressionProtoConverterSuite.scala     |   5 +-
 3 files changed, 63 insertions(+), 93 deletions(-)

diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index a4d8b0f2a02d..293ffe17bb4f 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -320,7 +320,7 @@ object LiteralValueProtoConverter {
         toCatalystArray(literal.getArray)
 
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        toCatalystStruct(literal.getStruct)._1
+        toCatalystStruct(literal.getStruct)
 
       case other =>
         throw new UnsupportedOperationException(
@@ -328,9 +328,7 @@ object LiteralValueProtoConverter {
     }
   }
 
-  private def getConverter(
-      dataType: proto.DataType,
-      inferDataType: Boolean = false): proto.Expression.Literal => Any = {
+  private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
     dataType.getKindCase match {
       case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
       case proto.DataType.KindCase.INTEGER => v => v.getInteger
@@ -354,20 +352,15 @@ object LiteralValueProtoConverter {
       case proto.DataType.KindCase.ARRAY => v => toCatalystArray(v.getArray)
       case proto.DataType.KindCase.MAP => v => toCatalystMap(v.getMap)
       case proto.DataType.KindCase.STRUCT =>
-        if (inferDataType) { v =>
-          val (struct, structType) = toCatalystStruct(v.getStruct, None)
-          LiteralValueWithDataType(
-            struct,
-            proto.DataType.newBuilder.setStruct(structType).build())
-        } else { v =>
-          toCatalystStruct(v.getStruct, Some(dataType.getStruct))._1
-        }
+        v => toCatalystStructInternal(v.getStruct, dataType.getStruct)
       case _ =>
         throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
     }
   }
 
-  private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = {
+  private def getInferredDataType(
+      literal: proto.Expression.Literal,
+      recursive: Boolean = false): Option[proto.DataType] = {
     if (literal.hasNull) {
       return Some(literal.getNull)
     }
@@ -399,8 +392,31 @@ object LiteralValueProtoConverter {
      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
        builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build())
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        // The type of the fields will be inferred from the literals of the fields in the struct.
-        builder.setStruct(literal.getStruct.getStructType.getStruct)
+        if (recursive) {
+          val structType = literal.getStruct.getDataTypeStruct
+          val structData = literal.getStruct.getElementsList.asScala
+          val structTypeBuilder = proto.DataType.Struct.newBuilder
+          for ((element, field) <- structData.zip(structType.getFieldsList.asScala)) {
+            if (field.hasDataType) {
+              structTypeBuilder.addFields(field)
+            } else {
+              getInferredDataType(element, recursive = true) match {
+                case Some(dataType) =>
+                  val fieldBuilder = structTypeBuilder.addFieldsBuilder()
+                  fieldBuilder.setName(field.getName)
+                  fieldBuilder.setDataType(dataType)
+                  fieldBuilder.setNullable(field.getNullable)
+                  if (field.hasMetadata) {
+                    fieldBuilder.setMetadata(field.getMetadata)
+                  }
+                case None => return None
+              }
+            }
+          }
+          builder.setStruct(structTypeBuilder.build())
+        } else {
+          builder.setStruct(proto.DataType.Struct.newBuilder.build())
+        }
       case _ =>
        // Not all data types support inferring the data type from the literal at the moment.
        // e.g. the type of DayTimeInterval contains extra information like start_field and
@@ -410,13 +426,6 @@ object LiteralValueProtoConverter {
     Some(builder.build())
   }
 
-  private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = {
-    getInferredDataType(literal).getOrElse {
-      throw InvalidPlanInput(
-        s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}")
-    }
-  }
-
   def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = {
     def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
         tag: ClassTag[T]): Array[T] = {
@@ -451,9 +460,9 @@ object LiteralValueProtoConverter {
     makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
   }
 
-  def toCatalystStruct(
+  private def toCatalystStructInternal(
       struct: proto.Expression.Literal.Struct,
-      structTypeOpt: Option[proto.DataType.Struct] = None): (Any, proto.DataType.Struct) = {
+      structType: proto.DataType.Struct): Any = {
     def toTuple[A <: Object](data: Seq[A]): Product = {
       try {
        val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
@@ -464,78 +473,36 @@ object LiteralValueProtoConverter {
       }
     }
 
-    if (struct.hasDataTypeStruct) {
-      // The new way to define and convert structs.
-      val (structData, structType) = if (structTypeOpt.isDefined) {
-        val structFields = structTypeOpt.get.getFieldsList.asScala
-        val structData =
-          struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
-            getConverter(structField.getDataType)(element)
-          }
-        (structData, structTypeOpt.get)
-      } else {
-        def protoStructField(
-            name: String,
-            dataType: proto.DataType,
-            nullable: Boolean,
-            metadata: Option[String]): proto.DataType.StructField = {
-          val builder = proto.DataType.StructField
-            .newBuilder()
-            .setName(name)
-            .setDataType(dataType)
-            .setNullable(nullable)
-          metadata.foreach(builder.setMetadata)
-          builder.build()
-        }
-
-        val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala
-
-        val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
-          case (element, dataTypeField) =>
-            if (dataTypeField.hasDataType) {
-              (getConverter(dataTypeField.getDataType)(element), dataTypeField)
-            } else {
-              val outerDataType = getInferredDataTypeOrThrow(element)
-              val (value, dataType) =
-                getConverter(outerDataType, inferDataType = true)(element) match {
-                  case LiteralValueWithDataType(value, dataType) => (value, dataType)
-                  case value => (value, outerDataType)
-                }
-              (
-                value,
-                protoStructField(
-                  dataTypeField.getName,
-                  dataType,
-                  dataTypeField.getNullable,
-                  if (dataTypeField.hasMetadata) Some(dataTypeField.getMetadata) else None))
-            }
-        }
+    val elements = struct.getElementsList.asScala
+    val dataTypes = structType.getFieldsList.asScala.map(_.getDataType)
+    val structData = elements
+      .zip(dataTypes)
+      .map { case (element, dataType) =>
+        getConverter(dataType)(element)
+      }
+      .asInstanceOf[scala.collection.Seq[Object]]
+      .toSeq
 
-        val structType = proto.DataType.Struct
-          .newBuilder()
-          .addAllFields(structDataAndFields.map(_._2).asJava)
-          .build()
+    toTuple(structData)
+  }
 
-        (structDataAndFields.map(_._1), structType)
+  def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = {
+    if (struct.hasDataTypeStruct) {
+      val literal = proto.Expression.Literal.newBuilder().setStruct(struct).build()
+      getInferredDataType(literal, recursive = true) match {
+        case Some(dataType) => dataType.getStruct
+        case None => throw InvalidPlanInput("Cannot infer data type from this struct literal.")
       }
-      (toTuple(structData.toSeq.asInstanceOf[Seq[Object]]), structType)
     } else if (struct.hasStructType) {
-      // For backward compatibility, we still support the old way to define and convert structs.
-      val elements = struct.getElementsList.asScala
-      val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
-      val structData = elements
-        .zip(dataTypes)
-        .map { case (element, dataType) =>
-          getConverter(dataType)(element)
-        }
-        .asInstanceOf[scala.collection.Seq[Object]]
-        .toSeq
-
-      (toTuple(structData), struct.getStructType.getStruct)
+      // For backward compatibility, we still support the old way to
+      // define and convert struct types.
+      struct.getStructType.getStruct
     } else {
      throw InvalidPlanInput("Data type information is missing in the struct literal.")
     }
   }
 
-  private case class LiteralValueWithDataType(value: Any, dataType: proto.DataType)
+  def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
+    toCatalystStructInternal(struct, getProtoStructType(struct))
+  }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
index 10f046a57da9..f4c56d461bd2 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
@@ -117,9 +117,11 @@ object LiteralExpressionProtoConverter {
             DataTypeProtoConverter.toCatalystType(lit.getMap.getValueType)))
 
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        val (structData, structType) = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct)
+        val structData = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct)
         val dataType = DataTypeProtoConverter.toCatalystType(
-          proto.DataType.newBuilder.setStruct(structType).build())
+          proto.DataType.newBuilder
+            .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
+            .build())
        val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
         expressions.Literal(convert(structData), dataType)
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
index 559984e47cf8..71fcd2b39492 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
@@ -99,7 +99,8 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
       .addElements(LiteralValueProtoConverter.toLiteralProto("test"))
       .build()
 
-    val (result, resultType) = LiteralValueProtoConverter.toCatalystStruct(structProto)
+    val result = LiteralValueProtoConverter.toCatalystStruct(structProto)
+    val resultType = LiteralValueProtoConverter.getProtoStructType(structProto)
 
     // Verify the result is a tuple with correct values
     assert(result.isInstanceOf[Product])
@@ -156,7 +157,7 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
     assert(!structFields.get(1).getNullable)
     assert(!structFields.get(1).hasMetadata)
 
-    val (_, structTypeProto) = LiteralValueProtoConverter.toCatalystStruct(literalProto.getStruct)
+    val structTypeProto = LiteralValueProtoConverter.getProtoStructType(literalProto.getStruct)
     assert(structTypeProto.getFieldsList.get(0).getNullable)
     assert(structTypeProto.getFieldsList.get(0).hasMetadata)
    assert(structTypeProto.getFieldsList.get(0).getMetadata == """{"key":"value"}""")

