This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 0dc71c091 chore: Refactor serde for more array and struct expressions (#2257)
0dc71c091 is described below
commit 0dc71c091d1c10ec5480235ec3acc28e3522fe89
Author: Andy Grove <[email protected]>
AuthorDate: Fri Aug 29 16:59:18 2025 -0600
chore: Refactor serde for more array and struct expressions (#2257)
---
.../org/apache/comet/serde/QueryPlanSerde.scala | 170 +--------------------
.../main/scala/org/apache/comet/serde/arrays.scala | 68 ++++++++-
.../scala/org/apache/comet/serde/structs.scala | 169 ++++++++++++++++++++
3 files changed, 242 insertions(+), 165 deletions(-)
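For context, this change moves the per-expression serde logic out of the large match in QueryPlanSerde.scala and into dedicated CometExpressionSerde[T] handler objects that are looked up through a classOf[Expression] -> handler registration map, which is why the diffstat shows 170 lines leaving QueryPlanSerde.scala and landing in arrays.scala and the new structs.scala. Below is a minimal, self-contained Scala sketch of that handler-per-expression pattern; SerdeSketch, SparkExpr, ProtoExpr, SerdeHandler, GetItemHandler and toProto are hypothetical stand-ins for illustration only, not the actual Catalyst expressions, ExprOuterClass.Expr, CometExpressionSerde or the objects added in this commit.

    object SerdeSketch {
      // Hypothetical stand-ins for Catalyst expressions and the protobuf Expr.
      sealed trait SparkExpr
      case class Col(name: String) extends SparkExpr
      case class Lit(value: Int) extends SparkExpr
      case class GetItem(child: SparkExpr, ordinal: SparkExpr) extends SparkExpr

      case class ProtoExpr(kind: String, children: Seq[ProtoExpr] = Nil)

      // Stand-in for CometExpressionSerde[T]: one handler object per expression class.
      trait SerdeHandler[T <: SparkExpr] {
        // None means "unsupported, fall back to Spark".
        def convert(expr: T): Option[ProtoExpr]
      }

      object GetItemHandler extends SerdeHandler[GetItem] {
        override def convert(expr: GetItem): Option[ProtoExpr] =
          for {
            child   <- toProto(expr.child)
            ordinal <- toProto(expr.ordinal)
          } yield ProtoExpr("ListExtract", Seq(child, ordinal))
      }

      // Registration map, analogous to classOf[GetArrayItem] -> CometGetArrayItem below.
      val handlers: Map[Class[_], SerdeHandler[_]] =
        Map(classOf[GetItem] -> GetItemHandler)

      def toProto(expr: SparkExpr): Option[ProtoExpr] = expr match {
        case Col(name)  => Some(ProtoExpr(s"column:$name"))
        case Lit(v)     => Some(ProtoExpr(s"literal:$v"))
        case e: GetItem => GetItemHandler.convert(e)
        case _          => None
      }

      def main(args: Array[String]): Unit =
        println(toProto(GetItem(Col("arr"), Lit(0))))
    }

Keeping QueryPlanSerde as a dispatcher over such a map, with the conversion details in per-expression objects, is the design the hunks below follow.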
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 22bd6fd03..ad9be300f 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -95,6 +95,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[ArraysOverlap] -> CometArraysOverlap,
classOf[ArrayUnion] -> CometArrayUnion,
classOf[CreateArray] -> CometCreateArray,
+ classOf[GetArrayItem] -> CometGetArrayItem,
+ classOf[ElementAt] -> CometElementAt,
classOf[Ascii] -> CometScalarFunction("ascii"),
classOf[ConcatWs] -> CometScalarFunction("concat_ws"),
classOf[Chr] -> CometScalarFunction("char"),
@@ -170,6 +172,10 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[DateSub] -> CometDateSub,
classOf[TruncDate] -> CometTruncDate,
classOf[TruncTimestamp] -> CometTruncTimestamp,
+ classOf[CreateNamedStruct] -> CometCreateNamedStruct,
+ classOf[GetStructField] -> CometGetStructField,
+ classOf[GetArrayStructFields] -> CometGetArrayStructFields,
+ classOf[StructsToJson] -> CometStructsToJson,
classOf[Flatten] -> CometFlatten,
classOf[Atan2] -> CometAtan2,
classOf[Ceil] -> CometCeil,
@@ -922,66 +928,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
None
}
- case StructsToJson(options, child, timezoneId) =>
- if (options.nonEmpty) {
- withInfo(expr, "StructsToJson with options is not supported")
- None
- } else {
-
- def isSupportedType(dt: DataType): Boolean = {
- dt match {
- case StructType(fields) =>
- fields.forall(f => isSupportedType(f.dataType))
- case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
- DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
- DataTypes.DoubleType | DataTypes.StringType =>
- true
- case DataTypes.DateType | DataTypes.TimestampType =>
- // TODO implement these types with tests for formatting options and timezone
- false
- case _: MapType | _: ArrayType =>
- // Spark supports map and array in StructsToJson but this is not yet
- // implemented in Comet
- false
- case _ => false
- }
- }
-
- val isSupported = child.dataType match {
- case s: StructType =>
- s.fields.forall(f => isSupportedType(f.dataType))
- case _: MapType | _: ArrayType =>
- // Spark supports map and array in StructsToJson but this is not yet
- // implemented in Comet
- false
- case _ =>
- false
- }
-
- if (isSupported) {
- exprToProtoInternal(child, inputs, binding) match {
- case Some(p) =>
- val toJson = ExprOuterClass.ToJson
- .newBuilder()
- .setChild(p)
- .setTimezone(timezoneId.getOrElse("UTC"))
- .setIgnoreNullFields(true)
- .build()
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setToJson(toJson)
- .build())
- case _ =>
- withInfo(expr, child)
- None
- }
- } else {
- withInfo(expr, "Unsupported data type", child)
- None
- }
- }
-
case SortOrder(child, direction, nullOrdering, _) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
@@ -1336,110 +1282,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
withInfo(expr, bloomFilter, value)
None
}
-
- case struct @ CreateNamedStruct(_) =>
- if (struct.names.length != struct.names.distinct.length) {
- withInfo(expr, "CreateNamedStruct with duplicate field names are not
supported")
- return None
- }
-
- val valExprs = struct.valExprs.map(exprToProtoInternal(_, inputs, binding))
-
- if (valExprs.forall(_.isDefined)) {
- val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
- structBuilder.addAllValues(valExprs.map(_.get).asJava)
- structBuilder.addAllNames(struct.names.map(_.toString).asJava)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setCreateNamedStruct(structBuilder)
- .build())
- } else {
- withInfo(expr, "unsupported arguments for CreateNamedStruct",
struct.valExprs: _*)
- None
- }
-
- case GetStructField(child, ordinal, _) =>
- exprToProtoInternal(child, inputs, binding).map { childExpr =>
- val getStructFieldBuilder = ExprOuterClass.GetStructField
- .newBuilder()
- .setChild(childExpr)
- .setOrdinal(ordinal)
-
- ExprOuterClass.Expr
- .newBuilder()
- .setGetStructField(getStructFieldBuilder)
- .build()
- }
-
- case GetArrayItem(child, ordinal, failOnError) =>
- val childExpr = exprToProtoInternal(child, inputs, binding)
- val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
-
- if (childExpr.isDefined && ordinalExpr.isDefined) {
- val listExtractBuilder = ExprOuterClass.ListExtract
- .newBuilder()
- .setChild(childExpr.get)
- .setOrdinal(ordinalExpr.get)
- .setOneBased(false)
- .setFailOnError(failOnError)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setListExtract(listExtractBuilder)
- .build())
- } else {
- withInfo(expr, "unsupported arguments for GetArrayItem", child,
ordinal)
- None
- }
-
- case ElementAt(child, ordinal, defaultValue, failOnError)
- if child.dataType.isInstanceOf[ArrayType] =>
- val childExpr = exprToProtoInternal(child, inputs, binding)
- val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
- val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))
-
- if (childExpr.isDefined && ordinalExpr.isDefined &&
- defaultExpr.isDefined == defaultValue.isDefined) {
- val arrayExtractBuilder = ExprOuterClass.ListExtract
- .newBuilder()
- .setChild(childExpr.get)
- .setOrdinal(ordinalExpr.get)
- .setOneBased(true)
- .setFailOnError(failOnError)
-
- defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setListExtract(arrayExtractBuilder)
- .build())
- } else {
- withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
- None
- }
-
- case GetArrayStructFields(child, _, ordinal, _, _) =>
- val childExpr = exprToProtoInternal(child, inputs, binding)
-
- if (childExpr.isDefined) {
- val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
- .newBuilder()
- .setChild(childExpr.get)
- .setOrdinal(ordinal)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setGetArrayStructFields(arrayStructFieldsBuilder)
- .build())
- } else {
- withInfo(expr, "unsupported arguments for GetArrayStructFields",
child)
- None
- }
case af @ ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
convert(af, CometArrayCompact)
case expr =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 411ef00b4..5b1603aaf 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
import scala.annotation.tailrec
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, Expression, Flatten, Literal}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -404,6 +404,72 @@ object CometCreateArray extends CometExpressionSerde[CreateArray] {
}
}
+object CometGetArrayItem extends CometExpressionSerde[GetArrayItem] {
+ override def convert(
+ expr: GetArrayItem,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
+ val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+ val ordinalExpr = exprToProtoInternal(expr.ordinal, inputs, binding)
+
+ if (childExpr.isDefined && ordinalExpr.isDefined) {
+ val listExtractBuilder = ExprOuterClass.ListExtract
+ .newBuilder()
+ .setChild(childExpr.get)
+ .setOrdinal(ordinalExpr.get)
+ .setOneBased(false)
+ .setFailOnError(expr.failOnError)
+
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setListExtract(listExtractBuilder)
+ .build())
+ } else {
+ withInfo(expr, "unsupported arguments for GetArrayItem", expr.child,
expr.ordinal)
+ None
+ }
+ }
+}
+
+object CometElementAt extends CometExpressionSerde[ElementAt] {
+
+ override def convert(
+ expr: ElementAt,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
+ val childExpr = exprToProtoInternal(expr.left, inputs, binding)
+ val ordinalExpr = exprToProtoInternal(expr.right, inputs, binding)
+ val defaultExpr = expr.defaultValueOutOfBound.flatMap(exprToProtoInternal(_, inputs, binding))
+
+ if (!expr.left.dataType.isInstanceOf[ArrayType]) {
+ withInfo(expr, "Input is not an array")
+ return None
+ }
+
+ if (childExpr.isDefined && ordinalExpr.isDefined &&
+ defaultExpr.isDefined == expr.defaultValueOutOfBound.isDefined) {
+ val arrayExtractBuilder = ExprOuterClass.ListExtract
+ .newBuilder()
+ .setChild(childExpr.get)
+ .setOrdinal(ordinalExpr.get)
+ .setOneBased(true)
+ .setFailOnError(expr.failOnError)
+
+ defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
+
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setListExtract(arrayExtractBuilder)
+ .build())
+ } else {
+ withInfo(expr, "unsupported arguments for ElementAt", expr.left,
expr.right)
+ None
+ }
+ }
+}
+
object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase {
override def convert(
diff --git a/spark/src/main/scala/org/apache/comet/serde/structs.scala b/spark/src/main/scala/org/apache/comet/serde/structs.scala
new file mode 100644
index 000000000..1c25d87bb
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/structs.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetArrayStructFields, GetStructField, StructsToJson}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType}
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal
+
+object CometCreateNamedStruct extends CometExpressionSerde[CreateNamedStruct] {
+ override def convert(
+ expr: CreateNamedStruct,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
+ if (expr.names.length != expr.names.distinct.length) {
+ withInfo(expr, "CreateNamedStruct with duplicate field names are not
supported")
+ return None
+ }
+
+ val valExprs = expr.valExprs.map(exprToProtoInternal(_, inputs, binding))
+
+ if (valExprs.forall(_.isDefined)) {
+ val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
+ structBuilder.addAllValues(valExprs.map(_.get).asJava)
+ structBuilder.addAllNames(expr.names.map(_.toString).asJava)
+
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setCreateNamedStruct(structBuilder)
+ .build())
+ } else {
+ withInfo(expr, "unsupported arguments for CreateNamedStruct",
expr.valExprs: _*)
+ None
+ }
+
+ }
+}
+
+object CometGetStructField extends CometExpressionSerde[GetStructField] {
+ override def convert(
+ expr: GetStructField,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
+ exprToProtoInternal(expr.child, inputs, binding).map { childExpr =>
+ val getStructFieldBuilder = ExprOuterClass.GetStructField
+ .newBuilder()
+ .setChild(childExpr)
+ .setOrdinal(expr.ordinal)
+
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setGetStructField(getStructFieldBuilder)
+ .build()
+ }
+ }
+}
+
+object CometGetArrayStructFields extends CometExpressionSerde[GetArrayStructFields] {
+ override def convert(
+ expr: GetArrayStructFields,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
+ val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+
+ if (childExpr.isDefined) {
+ val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+ .newBuilder()
+ .setChild(childExpr.get)
+ .setOrdinal(expr.ordinal)
+
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setGetArrayStructFields(arrayStructFieldsBuilder)
+ .build())
+ } else {
+ withInfo(expr, "unsupported arguments for GetArrayStructFields",
expr.child)
+ None
+ }
+ }
+}
+
+object CometStructsToJson extends CometExpressionSerde[StructsToJson] {
+
+ override def convert(
+ expr: StructsToJson,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
+ if (expr.options.nonEmpty) {
+ withInfo(expr, "StructsToJson with options is not supported")
+ None
+ } else {
+
+ def isSupportedType(dt: DataType): Boolean = {
+ dt match {
+ case StructType(fields) =>
+ fields.forall(f => isSupportedType(f.dataType))
+ case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
+ DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
+ DataTypes.DoubleType | DataTypes.StringType =>
+ true
+ case DataTypes.DateType | DataTypes.TimestampType =>
+ // TODO implement these types with tests for formatting options and timezone
+ false
+ case _: MapType | _: ArrayType =>
+ // Spark supports map and array in StructsToJson but this is not yet
+ // implemented in Comet
+ false
+ case _ => false
+ }
+ }
+
+ val isSupported = expr.child.dataType match {
+ case s: StructType =>
+ s.fields.forall(f => isSupportedType(f.dataType))
+ case _: MapType | _: ArrayType =>
+ // Spark supports map and array in StructsToJson but this is not yet
+ // implemented in Comet
+ false
+ case _ =>
+ false
+ }
+
+ if (isSupported) {
+ exprToProtoInternal(expr.child, inputs, binding) match {
+ case Some(p) =>
+ val toJson = ExprOuterClass.ToJson
+ .newBuilder()
+ .setChild(p)
+ .setTimezone(expr.timeZoneId.getOrElse("UTC"))
+ .setIgnoreNullFields(true)
+ .build()
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setToJson(toJson)
+ .build())
+ case _ =>
+ withInfo(expr, expr.child)
+ None
+ }
+ } else {
+ withInfo(expr, "Unsupported data type", expr.child)
+ None
+ }
+ }
+ }
+}
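As a usage note, the expressions covered by the new handlers surface in Spark SQL as named_struct, struct field access, array indexing with [], element_at, and to_json. The following is an illustrative snippet, not part of this commit; it assumes a local SparkSession with the Comet plugin on the classpath and enabled via its documented spark.plugins configuration (not shown here), so that these expressions are serialized through the handlers above.

    import org.apache.spark.sql.SparkSession

    object CometExprDemo {
      def main(args: Array[String]): Unit = {
        // Illustrative only: Comet configuration is assumed, not shown.
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("comet-expr-demo")
          .getOrCreate()

        spark.sql(
          """SELECT named_struct('a', 1, 'b', 'x')   AS s,            -- CreateNamedStruct
            |       array(10, 20, 30)[0]             AS first_item,   -- GetArrayItem (0-based)
            |       element_at(array(10, 20, 30), 2) AS second_item,  -- ElementAt (1-based)
            |       to_json(named_struct('a', 1))    AS json_col      -- StructsToJson
            |""".stripMargin).show(truncate = false)

        spark.stop()
      }
    }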
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]