This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 3da912c4a feat: Add new trait for operator serde (#2115)
3da912c4a is described below
commit 3da912c4ae9845c0605557ccac4bd7156c9360f9
Author: Andy Grove <[email protected]>
AuthorDate: Tue Aug 12 12:32:13 2025 -0600
feat: Add new trait for operator serde (#2115)
---
.../org/apache/comet/serde/CometProject.scala | 52 +++++++++
.../scala/org/apache/comet/serde/CometSort.scala | 58 +++++++++
.../org/apache/comet/serde/QueryPlanSerde.scala | 130 ++++++++++++---------
3 files changed, 186 insertions(+), 54 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometProject.scala b/spark/src/main/scala/org/apache/comet/serde/CometProject.scala
new file mode 100644
index 000000000..ad48ef27f
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometProject.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.ProjectExec
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.exprToProto
+
+object CometProject extends CometOperatorSerde[ProjectExec] {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] =
+ Some(CometConf.COMET_EXEC_PROJECT_ENABLED)
+
+ override def convert(
+ op: ProjectExec,
+ builder: Operator.Builder,
+ childOp: Operator*): Option[OperatorOuterClass.Operator] = {
+ val exprs = op.projectList.map(exprToProto(_, op.child.output))
+
+ if (exprs.forall(_.isDefined) && childOp.nonEmpty) {
+ val projectBuilder = OperatorOuterClass.Projection
+ .newBuilder()
+ .addAllProjectList(exprs.map(_.get).asJava)
+ Some(builder.setProjection(projectBuilder).build())
+ } else {
+ withInfo(op, op.projectList: _*)
+ None
+ }
+ }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometSort.scala b/spark/src/main/scala/org/apache/comet/serde/CometSort.scala
new file mode 100644
index 000000000..5229c7601
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometSort.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.SortExec
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.{exprToProto, supportedSortType}
+
+object CometSort extends CometOperatorSerde[SortExec] {
+
+ override def enabledConfig: Option[ConfigEntry[Boolean]] =
+ Some(CometConf.COMET_EXEC_SORT_ENABLED)
+
+ override def convert(
+ op: SortExec,
+ builder: Operator.Builder,
+ childOp: Operator*): Option[OperatorOuterClass.Operator] = {
+ if (!supportedSortType(op, op.sortOrder)) {
+ withInfo(op, "Unsupported data type in sort expressions")
+ return None
+ }
+
+ val sortOrders = op.sortOrder.map(exprToProto(_, op.child.output))
+
+ if (sortOrders.forall(_.isDefined) && childOp.nonEmpty) {
+ val sortBuilder = OperatorOuterClass.Sort
+ .newBuilder()
+ .addAllSortOrders(sortOrders.map(_.get).asJava)
+ Some(builder.setSort(sortBuilder).build())
+ } else {
+ withInfo(op, "sort order not supported", op.sortOrder: _*)
+ None
+ }
+ }
+
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6a45b1ca2..35ebabdac 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -47,7 +47,7 @@ import org.apache.spark.unsafe.types.UTF8String
import com.google.protobuf.ByteString
-import org.apache.comet.CometConf
+import org.apache.comet.{CometConf, ConfigEntry}
import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo}
import org.apache.comet.DataTypeSupport.isComplexType
import org.apache.comet.expressions._
@@ -64,6 +64,12 @@ import org.apache.comet.shims.CometExprShim
*/
object QueryPlanSerde extends Logging with CometExprShim {
+ /**
+ * Mapping of Spark operator class to Comet operator handler.
+ */
+ private val opSerdeMap: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] =
+ Map(classOf[ProjectExec] -> CometProject, classOf[SortExec] -> CometSort)
+
/**
* Mapping of Spark expression class to Comet expression handler.
*/
@@ -1651,8 +1657,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
*/
def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = {
val conf = op.conf
- val result = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
- childOp.foreach(result.addChildren)
+ val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
+ childOp.foreach(builder.addChildren)
op match {
@@ -1669,7 +1675,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
nativeScanBuilder.addAllFields(scanTypes.asJava)
// Sink operators don't have children
- result.clearChildren()
+ builder.clearChildren()
if (conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED) &&
CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.get(conf)) {
@@ -1767,7 +1773,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
}
}
- Some(result.setNativeScan(nativeScanBuilder).build())
+ Some(builder.setNativeScan(nativeScanBuilder).build())
} else {
// There are unsupported scan type
@@ -1778,19 +1784,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
None
}
- case ProjectExec(projectList, child) if CometConf.COMET_EXEC_PROJECT_ENABLED.get(conf) =>
- val exprs = projectList.map(exprToProto(_, child.output))
-
- if (exprs.forall(_.isDefined) && childOp.nonEmpty) {
- val projectBuilder = OperatorOuterClass.Projection
- .newBuilder()
- .addAllProjectList(exprs.map(_.get).asJava)
- Some(result.setProjection(projectBuilder).build())
- } else {
- withInfo(op, projectList: _*)
- None
- }
-
case FilterExec(condition, child) if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) =>
val cond = exprToProto(condition, child.output)
@@ -1825,29 +1818,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
.setPredicate(cond.get)
.setUseDatafusionFilter(!containsNativeCometScan(op))
.setWrapChildInCopyExec(wrapChildInCopyExec(condition))
- Some(result.setFilter(filterBuilder).build())
+ Some(builder.setFilter(filterBuilder).build())
} else {
withInfo(op, condition, child)
None
}
- case SortExec(sortOrder, _, child, _) if CometConf.COMET_EXEC_SORT_ENABLED.get(conf) =>
- if (!supportedSortType(op, sortOrder)) {
- return None
- }
-
- val sortOrders = sortOrder.map(exprToProto(_, child.output))
-
- if (sortOrders.forall(_.isDefined) && childOp.nonEmpty) {
- val sortBuilder = OperatorOuterClass.Sort
- .newBuilder()
- .addAllSortOrders(sortOrders.map(_.get).asJava)
- Some(result.setSort(sortBuilder).build())
- } else {
- withInfo(op, "sort order not supported", sortOrder: _*)
- None
- }
-
case LocalLimitExec(limit, _) if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) =>
if (childOp.nonEmpty) {
// LocalLimit doesn't use offset, but it shares same operator serde class.
@@ -1856,7 +1832,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
.newBuilder()
.setLimit(limit)
.setOffset(0)
- Some(result.setLimit(limitBuilder).build())
+ Some(builder.setLimit(limitBuilder).build())
} else {
withInfo(op, "No child operator")
None
@@ -1872,7 +1848,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
// When we upgrade to Spark 3.3., we need to address it here.
limitBuilder.setLimit(globalLimitExec.limit)
- Some(result.setLimit(limitBuilder).build())
+ Some(builder.setLimit(limitBuilder).build())
} else {
withInfo(op, "No child operator")
None
@@ -1890,7 +1866,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
.newBuilder()
.addAllProjectList(projExprs.map(_.get).asJava)
.setNumExprPerProject(projections.head.size)
- Some(result.setExpand(expandBuilder).build())
+ Some(builder.setExpand(expandBuilder).build())
} else {
withInfo(op, allProjExprs: _*)
None
@@ -1935,7 +1911,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava)
windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava)
windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava)
- Some(result.setWindow(windowBuilder).build())
+ Some(builder.setWindow(windowBuilder).build())
} else {
None
}
@@ -2002,7 +1978,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
return None
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
- Some(result.setHashAgg(hashAggBuilder).build())
+ Some(builder.setHashAgg(hashAggBuilder).build())
} else {
val modes = aggregateExpressions.map(_.mode).distinct
@@ -2048,7 +2024,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
}
hashAggBuilder.setModeValue(mode.getNumber)
- Some(result.setHashAgg(hashAggBuilder).build())
+ Some(builder.setHashAgg(hashAggBuilder).build())
} else {
val allChildren: Seq[Expression] =
groupingExpressions ++ aggregateExpressions ++ aggregateAttributes
@@ -2110,7 +2086,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
.setBuildSide(
if (join.buildSide == BuildLeft) BuildSide.BuildLeft else
BuildSide.BuildRight)
condition.foreach(joinBuilder.setCondition)
- Some(result.setHashJoin(joinBuilder).build())
+ Some(builder.setHashJoin(joinBuilder).build())
} else {
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
withInfo(join, allExprs: _*)
@@ -2200,7 +2176,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
.addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
condition.map(joinBuilder.setCondition)
- Some(result.setSortMergeJoin(joinBuilder).build())
+ Some(builder.setSortMergeJoin(joinBuilder).build())
} else {
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
withInfo(join, allExprs: _*)
@@ -2236,9 +2212,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
scanBuilder.addAllFields(scanTypes.asJava)
// Sink operators don't have children
- result.clearChildren()
+ builder.clearChildren()
- Some(result.setScan(scanBuilder).build())
+ Some(builder.setScan(scanBuilder).build())
} else {
// There are unsupported scan type
val msg =
@@ -2249,15 +2225,29 @@ object QueryPlanSerde extends Logging with CometExprShim {
}
case op =>
- // Emit warning if:
- // 1. it is not Spark shuffle operator, which is handled separately
- // 2. it is not a Comet operator
- if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) {
- val msg = s"unsupported Spark operator: ${op.nodeName}"
- emitWarning(msg)
- withInfo(op, msg)
+ opSerdeMap.get(op.getClass) match {
+ case Some(handler) =>
+ handler.enabledConfig.foreach { enabledConfig =>
+ if (!enabledConfig.get(op.conf)) {
+ withInfo(
+ op,
+ s"Native support for operator ${op.getClass.getSimpleName} is disabled. " +
+ s"Set ${enabledConfig.key}=true to enable it.")
+ return None
+ }
+ }
+ handler.asInstanceOf[CometOperatorSerde[SparkPlan]].convert(op, builder, childOp: _*)
+ case _ =>
+ // Emit warning if:
+ // 1. it is not Spark shuffle operator, which is handled separately
+ // 2. it is not a Comet operator
+ if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) {
+ val msg = s"unsupported Spark operator: ${op.nodeName}"
+ emitWarning(msg)
+ withInfo(op, msg)
+ }
+ None
}
- None
}
}
@@ -2416,6 +2406,38 @@ object QueryPlanSerde extends Logging with CometExprShim {
}
}
+/**
+ * Trait for providing serialization logic for operators.
+ */
+trait CometOperatorSerde[T <: SparkPlan] {
+
+ /**
+ * Convert a Spark operator into a protocol buffer representation that can be passed into native
+ * code.
+ *
+ * @param op
+ * The Spark operator.
+ * @param builder
+ * The protobuf builder for the operator.
+ * @param childOp
+ * Child operators that have already been converted to Comet.
+ * @return
+ * Protocol buffer representation, or None if the operator could not be converted. In this
+ * case it is expected that the input operator will have been tagged with reasons why it could
+ * not be converted.
+ */
+ def convert(
+ op: T,
+ builder: Operator.Builder,
+ childOp: Operator*): Option[OperatorOuterClass.Operator]
+
+ /**
+ * Get the optional Comet configuration entry that is used to enable or disable native support
+ * for this operator.
+ */
+ def enabledConfig: Option[ConfigEntry[Boolean]]
+}
+
/**
* Trait for providing serialization logic for expressions.
*/
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]