This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2a0192546617 [SPARK-53354][CONNECT] Simplify LiteralValueProtoConverter.toCatalystStruct
2a0192546617 is described below

commit 2a019254661723428ab80f5c7609021f8f8cc616
Author:     Yihong He <heyihong...@gmail.com>
AuthorDate: Sun Aug 24 10:04:48 2025 +0800

[SPARK-53354][CONNECT] Simplify LiteralValueProtoConverter.toCatalystStruct

### What changes were proposed in this pull request?

This PR simplifies the `LiteralValueProtoConverter.toCatalystStruct` method by refactoring the struct conversion logic to be more straightforward and maintainable. The main changes include (a before/after sketch of the resulting API appears near the end of this message):

1. **Simplified return type**: Changed `toCatalystStruct` to return only the converted struct value instead of a tuple `(Any, proto.DataType.Struct)`
2. **Extracted struct type resolution**: Created a new `getProtoStructType` method to handle struct type resolution separately
3. **Simplified internal conversion**: Introduced a `toCatalystStructInternal` method that takes the struct type as a parameter
4. **Removed complex type inference logic**: Eliminated the `LiteralValueWithDataType` case class and simplified the `getConverter` method by removing the `inferDataType` parameter
5. **Enhanced recursive type inference**: Improved the `getInferredDataType` method to support recursive type inference for struct fields

### Why are the changes needed?

The original method was overly complex, with multiple code paths and conditional logic that made it difficult to understand and maintain. This refactoring improves code readability while preserving the same functionality.

### Does this PR introduce _any_ user-facing change?

**No**. This is a pure refactoring that maintains the same external behavior and API.

### How was this patch tested?

`build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Cursor 1.4.5
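In caller terms, the API change looks like this (a minimal sketch based on the `LiteralExpressionProtoConverter` hunk in the diff below; `lit` stands for a `proto.Expression.Literal` holding a struct value):

```scala
// Before this patch: one call returned both the converted value and the
// resolved struct type as a tuple.
//   val (structData, structType) =
//     LiteralValueProtoConverter.toCatalystStruct(lit.getStruct)

// After this patch: value conversion and struct type resolution are
// separate entry points.
val structData = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct)
val structType = LiteralValueProtoConverter.getProtoStructType(lit.getStruct)
```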
Closes #52098 from heyihong/SPARK-53354.

Authored-by: Yihong He <heyihong...@gmail.com>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../common/LiteralValueProtoConverter.scala        | 145 ++++++++-------------
 .../planner/LiteralExpressionProtoConverter.scala  |   6 +-
 .../LiteralExpressionProtoConverterSuite.scala     |   5 +-
 3 files changed, 63 insertions(+), 93 deletions(-)

diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index a4d8b0f2a02d..293ffe17bb4f 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -320,7 +320,7 @@ object LiteralValueProtoConverter {
         toCatalystArray(literal.getArray)
 
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        toCatalystStruct(literal.getStruct)._1
+        toCatalystStruct(literal.getStruct)
 
       case other =>
         throw new UnsupportedOperationException(
@@ -328,9 +328,7 @@
     }
   }
 
-  private def getConverter(
-      dataType: proto.DataType,
-      inferDataType: Boolean = false): proto.Expression.Literal => Any = {
+  private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
     dataType.getKindCase match {
       case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
       case proto.DataType.KindCase.INTEGER => v => v.getInteger
@@ -354,20 +352,15 @@ object LiteralValueProtoConverter {
       case proto.DataType.KindCase.ARRAY => v => toCatalystArray(v.getArray)
       case proto.DataType.KindCase.MAP => v => toCatalystMap(v.getMap)
       case proto.DataType.KindCase.STRUCT =>
-        if (inferDataType) { v =>
-          val (struct, structType) = toCatalystStruct(v.getStruct, None)
-          LiteralValueWithDataType(
-            struct,
-            proto.DataType.newBuilder.setStruct(structType).build())
-        } else { v =>
-          toCatalystStruct(v.getStruct, Some(dataType.getStruct))._1
-        }
+        v => toCatalystStructInternal(v.getStruct, dataType.getStruct)
       case _ =>
         throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
     }
   }
 
-  private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = {
+  private def getInferredDataType(
+      literal: proto.Expression.Literal,
+      recursive: Boolean = false): Option[proto.DataType] = {
     if (literal.hasNull) {
       return Some(literal.getNull)
     }
@@ -399,8 +392,31 @@ object LiteralValueProtoConverter {
       case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
         builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build())
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        // The type of the fields will be inferred from the literals of the fields in the struct.
-        builder.setStruct(literal.getStruct.getStructType.getStruct)
+        if (recursive) {
+          val structType = literal.getStruct.getDataTypeStruct
+          val structData = literal.getStruct.getElementsList.asScala
+          val structTypeBuilder = proto.DataType.Struct.newBuilder
+          for ((element, field) <- structData.zip(structType.getFieldsList.asScala)) {
+            if (field.hasDataType) {
+              structTypeBuilder.addFields(field)
+            } else {
+              getInferredDataType(element, recursive = true) match {
+                case Some(dataType) =>
+                  val fieldBuilder = structTypeBuilder.addFieldsBuilder()
+                  fieldBuilder.setName(field.getName)
+                  fieldBuilder.setDataType(dataType)
+                  fieldBuilder.setNullable(field.getNullable)
+                  if (field.hasMetadata) {
+                    fieldBuilder.setMetadata(field.getMetadata)
+                  }
+                case None => return None
+              }
+            }
+          }
+          builder.setStruct(structTypeBuilder.build())
+        } else {
+          builder.setStruct(proto.DataType.Struct.newBuilder.build())
+        }
       case _ =>
         // Not all data types support inferring the data type from the literal at the moment.
         // e.g. the type of DayTimeInterval contains extra information like start_field and
@@ -410,13 +426,6 @@ object LiteralValueProtoConverter {
     Some(builder.build())
   }
 
-  private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = {
-    getInferredDataType(literal).getOrElse {
-      throw InvalidPlanInput(
-        s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}")
-    }
-  }
-
   def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = {
     def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
         tag: ClassTag[T]): Array[T] = {
@@ -451,9 +460,9 @@ object LiteralValueProtoConverter {
     makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
   }
 
-  def toCatalystStruct(
+  private def toCatalystStructInternal(
      struct: proto.Expression.Literal.Struct,
-      structTypeOpt: Option[proto.DataType.Struct] = None): (Any, proto.DataType.Struct) = {
+      structType: proto.DataType.Struct): Any = {
     def toTuple[A <: Object](data: Seq[A]): Product = {
       try {
         val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
@@ -464,78 +473,36 @@ object LiteralValueProtoConverter {
       }
     }
 
-    if (struct.hasDataTypeStruct) {
-      // The new way to define and convert structs.
-      val (structData, structType) = if (structTypeOpt.isDefined) {
-        val structFields = structTypeOpt.get.getFieldsList.asScala
-        val structData =
-          struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
-            getConverter(structField.getDataType)(element)
-          }
-        (structData, structTypeOpt.get)
-      } else {
-        def protoStructField(
-            name: String,
-            dataType: proto.DataType,
-            nullable: Boolean,
-            metadata: Option[String]): proto.DataType.StructField = {
-          val builder = proto.DataType.StructField
-            .newBuilder()
-            .setName(name)
-            .setDataType(dataType)
-            .setNullable(nullable)
-          metadata.foreach(builder.setMetadata)
-          builder.build()
-        }
-
-        val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala
-
-        val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
-          case (element, dataTypeField) =>
-            if (dataTypeField.hasDataType) {
-              (getConverter(dataTypeField.getDataType)(element), dataTypeField)
-            } else {
-              val outerDataType = getInferredDataTypeOrThrow(element)
-              val (value, dataType) =
-                getConverter(outerDataType, inferDataType = true)(element) match {
-                  case LiteralValueWithDataType(value, dataType) => (value, dataType)
-                  case value => (value, outerDataType)
-                }
-              (
-                value,
-                protoStructField(
-                  dataTypeField.getName,
-                  dataType,
-                  dataTypeField.getNullable,
-                  if (dataTypeField.hasMetadata) Some(dataTypeField.getMetadata) else None))
-            }
-        }
+    val elements = struct.getElementsList.asScala
+    val dataTypes = structType.getFieldsList.asScala.map(_.getDataType)
+    val structData = elements
+      .zip(dataTypes)
+      .map { case (element, dataType) =>
+        getConverter(dataType)(element)
+      }
+      .asInstanceOf[scala.collection.Seq[Object]]
+      .toSeq
 
-        val structType = proto.DataType.Struct
-          .newBuilder()
-          .addAllFields(structDataAndFields.map(_._2).asJava)
-          .build()
+    toTuple(structData)
+  }
 
-        (structDataAndFields.map(_._1), structType)
+  def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = {
+    if (struct.hasDataTypeStruct) {
+      val literal = proto.Expression.Literal.newBuilder().setStruct(struct).build()
+      getInferredDataType(literal, recursive = true) match {
+        case Some(dataType) => dataType.getStruct
+        case None => throw InvalidPlanInput("Cannot infer data type from this struct literal.")
       }
-      (toTuple(structData.toSeq.asInstanceOf[Seq[Object]]), structType)
     } else if (struct.hasStructType) {
-      // For backward compatibility, we still support the old way to define and convert structs.
-      val elements = struct.getElementsList.asScala
-      val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
-      val structData = elements
-        .zip(dataTypes)
-        .map { case (element, dataType) =>
-          getConverter(dataType)(element)
-        }
-        .asInstanceOf[scala.collection.Seq[Object]]
-        .toSeq
-
-      (toTuple(structData), struct.getStructType.getStruct)
+      // For backward compatibility, we still support the old way to
+      // define and convert struct types.
+      struct.getStructType.getStruct
     } else {
       throw InvalidPlanInput("Data type information is missing in the struct literal.")
     }
   }
 
-  private case class LiteralValueWithDataType(value: Any, dataType: proto.DataType)
+  def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
+    toCatalystStructInternal(struct, getProtoStructType(struct))
+  }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
index 10f046a57da9..f4c56d461bd2 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
@@ -117,9 +117,11 @@ object LiteralExpressionProtoConverter {
           DataTypeProtoConverter.toCatalystType(lit.getMap.getValueType)))
 
     case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-      val (structData, structType) = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct)
+      val structData = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct)
       val dataType = DataTypeProtoConverter.toCatalystType(
-        proto.DataType.newBuilder.setStruct(structType).build())
+        proto.DataType.newBuilder
+          .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
+          .build())
       val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
       expressions.Literal(convert(structData), dataType)
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
index 559984e47cf8..71fcd2b39492 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
@@ -99,7 +99,8 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
         .addElements(LiteralValueProtoConverter.toLiteralProto("test"))
         .build()
 
-    val (result, resultType) = LiteralValueProtoConverter.toCatalystStruct(structProto)
+    val result = LiteralValueProtoConverter.toCatalystStruct(structProto)
+    val resultType = LiteralValueProtoConverter.getProtoStructType(structProto)
 
     // Verify the result is a tuple with correct values
     assert(result.isInstanceOf[Product])
@@ -156,7 +157,7 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
     assert(!structFields.get(1).getNullable)
     assert(!structFields.get(1).hasMetadata)
 
-    val (_, structTypeProto) = LiteralValueProtoConverter.toCatalystStruct(literalProto.getStruct)
+    val structTypeProto = LiteralValueProtoConverter.getProtoStructType(literalProto.getStruct)
     assert(structTypeProto.getFieldsList.get(0).getNullable)
     assert(structTypeProto.getFieldsList.get(0).hasMetadata)
     assert(structTypeProto.getFieldsList.get(0).getMetadata == """{"key":"value"}""")
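As a complement to the updated tests above, here is a hedged, self-contained sketch of the new recursive inference path: a struct literal whose second field omits its data type, forcing `getProtoStructType` to infer it from the element literal. The import paths mirror the files in this diff; the `setDataTypeStruct` and `setInteger` setters are assumed from the protobuf-generated counterparts of the accessors used in the patch, and a `toLiteralProto(Any)` overload is assumed from the converter's public surface.

```scala
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter

// A struct literal (1, "test") with partial type information:
// field "a" is declared INTEGER, field "b" omits its data type.
val structProto = proto.Expression.Literal.Struct
  .newBuilder()
  .addElements(LiteralValueProtoConverter.toLiteralProto(1))
  .addElements(LiteralValueProtoConverter.toLiteralProto("test"))
  .setDataTypeStruct(
    proto.DataType.Struct
      .newBuilder()
      .addFields(
        proto.DataType.StructField
          .newBuilder()
          .setName("a")
          .setDataType(
            proto.DataType.newBuilder
              .setInteger(proto.DataType.Integer.newBuilder.build())
              .build())
          .setNullable(true))
      // No setDataType here: getInferredDataType(..., recursive = true)
      // fills in the type of "b" from its element literal.
      .addFields(proto.DataType.StructField.newBuilder().setName("b").setNullable(true))
      .build())
  .build()

// The refactored entry points separate the two concerns:
val value = LiteralValueProtoConverter.toCatalystStruct(structProto) // a scala.Tuple2
val structType = LiteralValueProtoConverter.getProtoStructType(structProto) // both fields typed
```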