This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 3da912c4a feat: Add new trait for operator serde (#2115)
3da912c4a is described below

commit 3da912c4ae9845c0605557ccac4bd7156c9360f9
Author: Andy Grove <[email protected]>
AuthorDate: Tue Aug 12 12:32:13 2025 -0600

    feat: Add new trait for operator serde (#2115)
---
 .../org/apache/comet/serde/CometProject.scala      |  52 +++++++++
 .../scala/org/apache/comet/serde/CometSort.scala   |  58 +++++++++
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 130 ++++++++++++---------
 3 files changed, 186 insertions(+), 54 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/CometProject.scala b/spark/src/main/scala/org/apache/comet/serde/CometProject.scala
new file mode 100644
index 000000000..ad48ef27f
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometProject.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.ProjectExec
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.exprToProto
+
+object CometProject extends CometOperatorSerde[ProjectExec] {
+  // Serde for Spark ProjectExec; native support is gated by COMET_EXEC_PROJECT_ENABLED.
+  override def enabledConfig: Option[ConfigEntry[Boolean]] =
+    Some(CometConf.COMET_EXEC_PROJECT_ENABLED)
+  // Builds a Projection from op.projectList, each bound to op.child.output via exprToProto.
+  override def convert(
+      op: ProjectExec,
+      builder: Operator.Builder,
+      childOp: Operator*): Option[OperatorOuterClass.Operator] = {
+    val exprs = op.projectList.map(exprToProto(_, op.child.output))
+    // All exprs must convert and a child must exist; otherwise tag op and return None.
+    if (exprs.forall(_.isDefined) && childOp.nonEmpty) {
+      val projectBuilder = OperatorOuterClass.Projection
+        .newBuilder()
+        .addAllProjectList(exprs.map(_.get).asJava)
+      Some(builder.setProjection(projectBuilder).build())
+    } else {
+      withInfo(op, op.projectList: _*) // NOTE(review): no message if only childOp is empty
+      None
+    }
+  }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometSort.scala b/spark/src/main/scala/org/apache/comet/serde/CometSort.scala
new file mode 100644
index 000000000..5229c7601
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/CometSort.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.SortExec
+
+import org.apache.comet.{CometConf, ConfigEntry}
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.{exprToProto, supportedSortType}
+
+object CometSort extends CometOperatorSerde[SortExec] {
+  // Serde for Spark SortExec; native support is gated by COMET_EXEC_SORT_ENABLED.
+  override def enabledConfig: Option[ConfigEntry[Boolean]] =
+    Some(CometConf.COMET_EXEC_SORT_ENABLED)
+  // Builds a Sort from op.sortOrder, each bound to op.child.output via exprToProto.
+  override def convert(
+      op: SortExec,
+      builder: Operator.Builder,
+      childOp: Operator*): Option[OperatorOuterClass.Operator] = {
+    if (!supportedSortType(op, op.sortOrder)) {
+      withInfo(op, "Unsupported data type in sort expressions")
+      return None
+    }
+
+    val sortOrders = op.sortOrder.map(exprToProto(_, op.child.output))
+    // All sort orders must convert and a child must exist; otherwise tag op and return None.
+    if (sortOrders.forall(_.isDefined) && childOp.nonEmpty) {
+      val sortBuilder = OperatorOuterClass.Sort
+        .newBuilder()
+        .addAllSortOrders(sortOrders.map(_.get).asJava)
+      Some(builder.setSort(sortBuilder).build())
+    } else {
+      withInfo(op, "sort order not supported", op.sortOrder: _*)
+      None
+    }
+  }
+
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6a45b1ca2..35ebabdac 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -47,7 +47,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 import com.google.protobuf.ByteString
 
-import org.apache.comet.CometConf
+import org.apache.comet.{CometConf, ConfigEntry}
 import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo}
 import org.apache.comet.DataTypeSupport.isComplexType
 import org.apache.comet.expressions._
@@ -64,6 +64,12 @@ import org.apache.comet.shims.CometExprShim
  */
 object QueryPlanSerde extends Logging with CometExprShim {
 
+  /**
+   * Mapping of Spark operator class to Comet operator handler.
+   */
+  private val opSerdeMap: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] =
+    Map(classOf[ProjectExec] -> CometProject, classOf[SortExec] -> CometSort)
+
   /**
    * Mapping of Spark expression class to Comet expression handler.
    */
@@ -1651,8 +1657,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
    */
   def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = {
     val conf = op.conf
-    val result = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
-    childOp.foreach(result.addChildren)
+    val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
+    childOp.foreach(builder.addChildren)
 
     op match {
 
@@ -1669,7 +1675,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
           nativeScanBuilder.addAllFields(scanTypes.asJava)
 
           // Sink operators don't have children
-          result.clearChildren()
+          builder.clearChildren()
 
           if (conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED) &&
             CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.get(conf)) {
@@ -1767,7 +1773,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
             }
           }
 
-          Some(result.setNativeScan(nativeScanBuilder).build())
+          Some(builder.setNativeScan(nativeScanBuilder).build())
 
         } else {
           // There are unsupported scan type
@@ -1778,19 +1784,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }
 
-      case ProjectExec(projectList, child) if 
CometConf.COMET_EXEC_PROJECT_ENABLED.get(conf) =>
-        val exprs = projectList.map(exprToProto(_, child.output))
-
-        if (exprs.forall(_.isDefined) && childOp.nonEmpty) {
-          val projectBuilder = OperatorOuterClass.Projection
-            .newBuilder()
-            .addAllProjectList(exprs.map(_.get).asJava)
-          Some(result.setProjection(projectBuilder).build())
-        } else {
-          withInfo(op, projectList: _*)
-          None
-        }
-
       case FilterExec(condition, child) if 
CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) =>
         val cond = exprToProto(condition, child.output)
 
@@ -1825,29 +1818,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
             .setPredicate(cond.get)
             .setUseDatafusionFilter(!containsNativeCometScan(op))
             .setWrapChildInCopyExec(wrapChildInCopyExec(condition))
-          Some(result.setFilter(filterBuilder).build())
+          Some(builder.setFilter(filterBuilder).build())
         } else {
           withInfo(op, condition, child)
           None
         }
 
-      case SortExec(sortOrder, _, child, _) if 
CometConf.COMET_EXEC_SORT_ENABLED.get(conf) =>
-        if (!supportedSortType(op, sortOrder)) {
-          return None
-        }
-
-        val sortOrders = sortOrder.map(exprToProto(_, child.output))
-
-        if (sortOrders.forall(_.isDefined) && childOp.nonEmpty) {
-          val sortBuilder = OperatorOuterClass.Sort
-            .newBuilder()
-            .addAllSortOrders(sortOrders.map(_.get).asJava)
-          Some(result.setSort(sortBuilder).build())
-        } else {
-          withInfo(op, "sort order not supported", sortOrder: _*)
-          None
-        }
-
       case LocalLimitExec(limit, _) if 
CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) =>
         if (childOp.nonEmpty) {
           // LocalLimit doesn't use offset, but it shares same operator serde 
class.
@@ -1856,7 +1832,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
             .newBuilder()
             .setLimit(limit)
             .setOffset(0)
-          Some(result.setLimit(limitBuilder).build())
+          Some(builder.setLimit(limitBuilder).build())
         } else {
           withInfo(op, "No child operator")
           None
@@ -1872,7 +1848,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
           // When we upgrade to Spark 3.3., we need to address it here.
           limitBuilder.setLimit(globalLimitExec.limit)
 
-          Some(result.setLimit(limitBuilder).build())
+          Some(builder.setLimit(limitBuilder).build())
         } else {
           withInfo(op, "No child operator")
           None
@@ -1890,7 +1866,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
             .newBuilder()
             .addAllProjectList(projExprs.map(_.get).asJava)
             .setNumExprPerProject(projections.head.size)
-          Some(result.setExpand(expandBuilder).build())
+          Some(builder.setExpand(expandBuilder).build())
         } else {
           withInfo(op, allProjExprs: _*)
           None
@@ -1935,7 +1911,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
           
windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava)
           windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava)
           windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava)
-          Some(result.setWindow(windowBuilder).build())
+          Some(builder.setWindow(windowBuilder).build())
         } else {
           None
         }
@@ -2002,7 +1978,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
             return None
           }
           hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
-          Some(result.setHashAgg(hashAggBuilder).build())
+          Some(builder.setHashAgg(hashAggBuilder).build())
         } else {
           val modes = aggregateExpressions.map(_.mode).distinct
 
@@ -2048,7 +2024,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
               hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
             }
             hashAggBuilder.setModeValue(mode.getNumber)
-            Some(result.setHashAgg(hashAggBuilder).build())
+            Some(builder.setHashAgg(hashAggBuilder).build())
           } else {
             val allChildren: Seq[Expression] =
               groupingExpressions ++ aggregateExpressions ++ 
aggregateAttributes
@@ -2110,7 +2086,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
             .setBuildSide(
               if (join.buildSide == BuildLeft) BuildSide.BuildLeft else 
BuildSide.BuildRight)
           condition.foreach(joinBuilder.setCondition)
-          Some(result.setHashJoin(joinBuilder).build())
+          Some(builder.setHashJoin(joinBuilder).build())
         } else {
           val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
           withInfo(join, allExprs: _*)
@@ -2200,7 +2176,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
             .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
             .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
           condition.map(joinBuilder.setCondition)
-          Some(result.setSortMergeJoin(joinBuilder).build())
+          Some(builder.setSortMergeJoin(joinBuilder).build())
         } else {
           val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
           withInfo(join, allExprs: _*)
@@ -2236,9 +2212,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
           scanBuilder.addAllFields(scanTypes.asJava)
 
           // Sink operators don't have children
-          result.clearChildren()
+          builder.clearChildren()
 
-          Some(result.setScan(scanBuilder).build())
+          Some(builder.setScan(scanBuilder).build())
         } else {
           // There are unsupported scan type
           val msg =
@@ -2249,15 +2225,29 @@ object QueryPlanSerde extends Logging with CometExprShim {
         }
 
       case op =>
-        // Emit warning if:
-        //  1. it is not Spark shuffle operator, which is handled separately
-        //  2. it is not a Comet operator
-        if (!op.nodeName.contains("Comet") && 
!op.isInstanceOf[ShuffleExchangeExec]) {
-          val msg = s"unsupported Spark operator: ${op.nodeName}"
-          emitWarning(msg)
-          withInfo(op, msg)
+        opSerdeMap.get(op.getClass) match {
+          case Some(handler) =>
+            handler.enabledConfig.foreach { enabledConfig =>
+              if (!enabledConfig.get(op.conf)) {
+                withInfo(
+                  op,
+                  s"Native support for operator ${op.getClass.getSimpleName} 
is disabled. " +
+                    s"Set ${enabledConfig.key}=true to enable it.")
+                return None
+              }
+            }
+            handler.asInstanceOf[CometOperatorSerde[SparkPlan]].convert(op, 
builder, childOp: _*)
+          case _ =>
+            // Emit warning if:
+            //  1. it is not Spark shuffle operator, which is handled 
separately
+            //  2. it is not a Comet operator
+            if (!op.nodeName.contains("Comet") && 
!op.isInstanceOf[ShuffleExchangeExec]) {
+              val msg = s"unsupported Spark operator: ${op.nodeName}"
+              emitWarning(msg)
+              withInfo(op, msg)
+            }
+            None
         }
-        None
     }
   }
 
@@ -2416,6 +2406,38 @@ object QueryPlanSerde extends Logging with CometExprShim {
   }
 }
 
+/**
+ * Trait for providing serialization logic for Spark operators of type `T`.
+ */
+trait CometOperatorSerde[T <: SparkPlan] {
+
+  /**
+   * Convert a Spark operator into a protocol buffer representation that can be passed into native
+   * code.
+   *
+   * @param op
+   *   The Spark operator.
+   * @param builder
+   *   The protobuf builder; operator2Proto has already set its plan id and child operators.
+   * @param childOp
+   *   Child operators that have already been converted to Comet.
+   * @return
+   *   Protocol buffer representation, or None if the operator could not be converted. In this
+   *   case it is expected that the input operator will have been tagged with reasons why it could
+   *   not be converted.
+   */
+  def convert(
+      op: T,
+      builder: Operator.Builder,
+      childOp: Operator*): Option[OperatorOuterClass.Operator]
+
+  /**
+   * Get the optional Comet configuration entry that is used to enable or disable native support
+   * for this operator. When None, operator2Proto performs no enablement check before `convert`.
+   */
+  def enabledConfig: Option[ConfigEntry[Boolean]]
+}
+
 /**
  * Trait for providing serialization logic for expressions.
  */


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to