This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 20be79df6 feat: CometExecRule refactor: Unify CometNativeExec creation with Serde in CometOperatorSerde trait (#2768)
20be79df6 is described below
commit 20be79df681152512181e0424e1cbb648ab63158
Author: Andy Grove <[email protected]>
AuthorDate: Thu Nov 13 11:34:03 2025 -0700
feat: CometExecRule refactor: Unify CometNativeExec creation with Serde in CometOperatorSerde trait (#2768)
---
.../org/apache/comet/rules/CometExecRule.scala | 214 +++------------------
.../apache/comet/serde/CometOperatorSerde.scala | 2 +
.../comet/serde/operator/CometAggregate.scala | 28 ++-
.../apache/comet/serde/operator/CometExpand.scala | 5 +-
.../apache/comet/serde/operator/CometFilter.scala | 4 +
.../comet/serde/operator/CometGlobalLimit.scala | 4 +
.../comet/serde/operator/CometHashJoin.scala | 33 ++++
.../comet/serde/operator/CometLocalLimit.scala | 5 +
.../comet/serde/operator/CometLocalTableScan.scala | 5 +
.../comet/serde/operator/CometNativeScan.scala | 5 +-
.../apache/comet/serde/operator/CometProject.scala | 5 +
.../apache/comet/serde/operator/CometSort.scala | 11 ++
.../comet/serde/operator/CometSortMergeJoin.scala | 15 ++
.../apache/comet/serde/operator/CometWindow.scala | 13 ++
.../org/apache/comet/exec/CometExecSuite.scala | 4 +-
.../spark/sql/comet/CometPlanStabilitySuite.scala | 2 +-
16 files changed, 159 insertions(+), 196 deletions(-)
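
The heart of this change is a new createExec hook on CometOperatorSerde: a single handler object now owns both the protobuf serialization of a Spark operator and the construction of the matching CometNativeExec, instead of CometExecRule spelling out every constructor inline. A trimmed sketch of the trait shape after this commit, assembled from the hunks below (doc comments and the config plumbing that gates each operator are omitted):

    import org.apache.spark.sql.comet.CometNativeExec
    import org.apache.spark.sql.execution.SparkPlan

    import org.apache.comet.serde.OperatorOuterClass
    import org.apache.comet.serde.OperatorOuterClass.Operator

    trait CometOperatorSerde[T <: SparkPlan] {

      // Serialize the Spark operator into its protobuf form, or return None
      // (after attaching fallback info) when it cannot be converted.
      def convert(
          op: T,
          builder: Operator.Builder,
          childOp: Operator*): Option[OperatorOuterClass.Operator]

      // New in this commit: build the Comet physical operator that wraps the
      // serialized native plan produced by convert.
      def createExec(nativeOp: Operator, op: T): CometNativeExec
    }
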
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index a0dc8ccb1..e94fdc861 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.types._
import org.apache.comet.{CometConf, ExtendedExplainInfo}
import org.apache.comet.CometConf.COMET_EXEC_SHUFFLE_ENABLED
import org.apache.comet.CometSparkSessionExtensions._
-import org.apache.comet.rules.CometExecRule.opSerdeMap
+import org.apache.comet.rules.CometExecRule.cometNativeExecHandlers
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, QueryPlanSerde, Unsupported}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataType}
@@ -55,7 +55,7 @@ object CometExecRule {
/**
* Mapping of Spark operator class to Comet operator handler.
*/
- val opSerdeMap: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] =
+ val cometNativeExecHandlers: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] =
Map(
classOf[ProjectExec] -> CometProject,
classOf[FilterExec] -> CometFilter,
@@ -183,7 +183,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
// Fully native scan for V1
case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
val nativeOp = operator2Proto(scan).get
- CometNativeScanExec(nativeOp, scan.wrapped, scan.session)
+ CometNativeScan.createExec(nativeOp, scan)
// Comet JVM + native scan for V1 and V2
case op if isCometScan(op) =>
@@ -195,36 +195,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
val nativeOp = operator2Proto(cometOp)
CometScanWrapper(nativeOp.get, cometOp)
- case op: ProjectExec =>
- newPlanWithProto(
- op,
- CometProjectExec(_, op, op.output, op.projectList, op.child, SerializedPlan(None)))
-
- case op: FilterExec =>
- newPlanWithProto(
- op,
- CometFilterExec(_, op, op.output, op.condition, op.child, SerializedPlan(None)))
-
- case op: SortExec =>
- newPlanWithProto(
- op,
- CometSortExec(
- _,
- op,
- op.output,
- op.outputOrdering,
- op.sortOrder,
- op.child,
- SerializedPlan(None)))
-
- case op: LocalLimitExec =>
- newPlanWithProto(op, CometLocalLimitExec(_, op, op.limit, op.child, SerializedPlan(None)))
-
- case op: GlobalLimitExec =>
- newPlanWithProto(
- op,
- CometGlobalLimitExec(_, op, op.limit, op.offset, op.child, SerializedPlan(None)))
-
case op: CollectLimitExec =>
val fallbackReasons = new ListBuffer[String]()
if (!CometConf.COMET_EXEC_COLLECT_LIMIT_ENABLED.get(conf)) {
@@ -250,116 +220,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}
}
- case op: ExpandExec =>
- newPlanWithProto(
- op,
- CometExpandExec(_, op, op.output, op.projections, op.child, SerializedPlan(None)))
-
- case op: HashAggregateExec =>
- newPlanWithProto(
- op,
- nativeOp => {
- CometHashAggregateExec(
- nativeOp,
- op,
- op.output,
- op.groupingExpressions,
- op.aggregateExpressions,
- op.resultExpressions,
- op.child.output,
- op.child,
- SerializedPlan(None))
- })
-
- case op: ObjectHashAggregateExec =>
- newPlanWithProto(
- op,
- nativeOp => {
- CometHashAggregateExec(
- nativeOp,
- op,
- op.output,
- op.groupingExpressions,
- op.aggregateExpressions,
- op.resultExpressions,
- op.child.output,
- op.child,
- SerializedPlan(None))
- })
-
- case op: ShuffledHashJoinExec
- if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
- op.children.forall(isCometNative) =>
- newPlanWithProto(
- op,
- CometHashJoinExec(
- _,
- op,
- op.output,
- op.outputOrdering,
- op.leftKeys,
- op.rightKeys,
- op.joinType,
- op.condition,
- op.buildSide,
- op.left,
- op.right,
- SerializedPlan(None)))
-
- case op: ShuffledHashJoinExec if !CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) =>
- withInfo(op, "ShuffleHashJoin is not enabled")
-
- case op: ShuffledHashJoinExec if !op.children.forall(isCometNative) =>
- op
-
- case op: BroadcastHashJoinExec
- if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
- op.children.forall(isCometNative) =>
- newPlanWithProto(
- op,
- CometBroadcastHashJoinExec(
- _,
- op,
- op.output,
- op.outputOrdering,
- op.leftKeys,
- op.rightKeys,
- op.joinType,
- op.condition,
- op.buildSide,
- op.left,
- op.right,
- SerializedPlan(None)))
-
- case op: SortMergeJoinExec
- if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
- op.children.forall(isCometNative) =>
- newPlanWithProto(
- op,
- CometSortMergeJoinExec(
- _,
- op,
- op.output,
- op.outputOrdering,
- op.leftKeys,
- op.rightKeys,
- op.joinType,
- op.condition,
- op.left,
- op.right,
- SerializedPlan(None)))
-
- case op: SortMergeJoinExec
- if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
- !op.children.forall(isCometNative) =>
- op
-
- case op: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
- withInfo(op, "SortMergeJoin is not enabled")
-
- case op: SortMergeJoinExec if !op.children.forall(isCometNative) =>
- op
-
case c @ CoalesceExec(numPartitions, child)
if CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf)
&& isCometNative(child) =>
@@ -405,19 +265,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
"TakeOrderedAndProject requires shuffle to be enabled")
withInfo(s, Seq(info1, info2).flatten.mkString(","))
- case w: WindowExec =>
- newPlanWithProto(
- w,
- CometWindowExec(
- _,
- w,
- w.output,
- w.windowExpression,
- w.partitionSpec,
- w.orderSpec,
- w.child,
- SerializedPlan(None)))
-
case u: UnionExec
if CometConf.COMET_EXEC_UNION_ENABLED.get(conf) &&
u.children.forall(isCometNative) =>
@@ -476,16 +323,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
plan
}
- // this case should be checked only after the previous case checking for a
- // child BroadcastExchange has been applied, otherwise that transform
- // never gets applied
- case op: BroadcastHashJoinExec if !op.children.forall(isCometNative) =>
- op
-
- case op: BroadcastHashJoinExec
- if !CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) =>
- withInfo(op, "BroadcastHashJoin is not enabled")
-
// For AQE shuffle stage on a Comet shuffle exchange
case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
@@ -548,19 +385,28 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
s
}
- case op: LocalTableScanExec =>
- if (CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.get(conf)) {
- operator2Proto(op)
- .map { nativeOp =>
- val cometOp = CometLocalTableScanExec(op, op.rows, op.output)
- CometScanWrapper(nativeOp, cometOp)
+ case op =>
+ // check if this is a fully native operator
+ cometNativeExecHandlers
+ .get(op.getClass)
+ .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]]) match {
+ case Some(handler) =>
+ if (op.children.forall(isCometNative)) {
+ if (isOperatorEnabled(handler, op)) {
+ val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
+ val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
+ childOp.foreach(builder.addChildren)
+ return handler
+ .convert(op, builder, childOp: _*)
+ .map(handler.createExec(_, op))
+ .getOrElse(op)
+ }
+ } else {
+ return op
}
- .getOrElse(op)
- } else {
- withInfo(op, "LocalTableScan is not enabled")
+ case _ =>
}
- case op =>
op match {
case _: CometPlan | _: AQEShuffleReadExec | _: BroadcastExchangeExec |
_: BroadcastQueryStageExec | _: AdaptiveSparkPlanExec =>
@@ -1030,20 +876,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
childOp.foreach(builder.addChildren)
- // look for registered handler first
- val serde = opSerdeMap.get(op.getClass)
- serde match {
- case Some(handler) if isOperatorEnabled(handler, op) =>
- val opSerde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
- val maybeConverted = opSerde.convert(op, builder, childOp: _*)
- if (maybeConverted.isDefined) {
- return maybeConverted
- }
- case _ =>
- }
-
- // now handle special cases that cannot be handled as a simple mapping from class name
- // and see if operator can be used as a sink
op match {
// Fully native scan for V1
@@ -1108,7 +940,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
// Emit warning if:
// 1. it is not Spark shuffle operator, which is handled separately
// 2. it is not a Comet operator
- if (serde.isEmpty && !op.nodeName.contains("Comet") &&
+ if (!op.nodeName.contains("Comet") &&
!op.isInstanceOf[ShuffleExchangeExec]) {
withInfo(op, s"unsupported Spark operator: ${op.nodeName}")
}
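
Unwrapped from the hunk above, the new generic branch in CometExecRule amounts to the flow sketched below. This is a simplified restatement for readability: convertViaHandler is a made-up helper name, the real logic is inlined in the rule's transform, and the fall-through to the remaining match arms and diagnostics is collapsed here into returning the original operator.

    // Sketch of the unified dispatch: look up the registered handler for the
    // operator class, serialize the operator, and let the handler build the exec.
    def convertViaHandler(op: SparkPlan): SparkPlan =
      cometNativeExecHandlers.get(op.getClass) match {
        case Some(h) if op.children.forall(isCometNative) =>
          val handler = h.asInstanceOf[CometOperatorSerde[SparkPlan]]
          if (isOperatorEnabled(handler, op)) {
            val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
            val childOps = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
            childOps.foreach(builder.addChildren)
            handler
              .convert(op, builder, childOps: _*) // protobuf serialization
              .map(handler.createExec(_, op))     // unified exec construction
              .getOrElse(op)                      // conversion declined: keep the Spark plan
          } else {
            op // operator disabled via its config
          }
        case _ => op // no registered handler: handled by the remaining cases
      }
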
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala
index 25ea2c123..3a2494591 100644
--- a/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala
@@ -19,6 +19,7 @@
package org.apache.comet.serde
+import org.apache.spark.sql.comet.CometNativeExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.comet.ConfigEntry
@@ -65,4 +66,5 @@ trait CometOperatorSerde[T <: SparkPlan] {
builder: Operator.Builder,
childOp: Operator*): Option[OperatorOuterClass.Operator]
+ def createExec(nativeOp: Operator, op: T): CometNativeExec
}
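
With convert and createExec side by side, wiring up a new native operator is reduced to one handler object plus an entry in cometNativeExecHandlers. A minimal sketch, assuming a hypothetical Spark operator MyLimitExec and Comet counterpart CometMyLimitExec (neither exists in this commit; only the trait shape and the SerializedPlan(None) convention come from the diff, and other trait members such as the enabling config are omitted):

    import org.apache.spark.sql.comet.{CometNativeExec, SerializedPlan}

    import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass}
    import org.apache.comet.serde.OperatorOuterClass.Operator

    // Hypothetical handler for a hypothetical MyLimitExec operator.
    object CometMyLimit extends CometOperatorSerde[MyLimitExec] {

      override def convert(
          op: MyLimitExec,
          builder: Operator.Builder,
          childOp: Operator*): Option[OperatorOuterClass.Operator] = {
        // Populate the operator-specific protobuf fields here; return the built
        // Operator, or None (with fallback info) when conversion is not possible.
        ???
      }

      override def createExec(nativeOp: Operator, op: MyLimitExec): CometNativeExec =
        // Wrap the serialized native plan in the matching Comet physical operator.
        CometMyLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None))
    }

The handler is then registered once, e.g. classOf[MyLimitExec] -> CometMyLimit in CometExecRule.cometNativeExecHandlers, and the generic dispatch shown earlier takes care of serialization and exec creation.
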
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala
index b0c359f08..3e541fe03 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala
@@ -23,7 +23,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
-import org.apache.spark.sql.comet.CometHashAggregateExec
+import org.apache.spark.sql.comet.{CometHashAggregateExec, CometNativeExec, SerializedPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
@@ -207,6 +207,19 @@ object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with Com
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
doConvert(aggregate, builder, childOp: _*)
}
+
+ override def createExec(nativeOp: Operator, op: HashAggregateExec): CometNativeExec = {
+ CometHashAggregateExec(
+ nativeOp,
+ op,
+ op.output,
+ op.groupingExpressions,
+ op.aggregateExpressions,
+ op.resultExpressions,
+ op.child.output,
+ op.child,
+ SerializedPlan(None))
+ }
}
object CometObjectHashAggregate
@@ -230,4 +243,17 @@ object CometObjectHashAggregate
doConvert(aggregate, builder, childOp: _*)
}
+
+ override def createExec(nativeOp: Operator, op: ObjectHashAggregateExec): CometNativeExec = {
+ CometHashAggregateExec(
+ nativeOp,
+ op,
+ op.output,
+ op.groupingExpressions,
+ op.aggregateExpressions,
+ op.resultExpressions,
+ op.child.output,
+ op.child,
+ SerializedPlan(None))
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometExpand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometExpand.scala
index ab5a58b06..5e9eb3f17 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometExpand.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometExpand.scala
@@ -22,6 +22,7 @@ package org.apache.comet.serde.operator
import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.comet.{CometExpandExec, CometNativeExec, SerializedPlan}
import org.apache.spark.sql.execution.ExpandExec
import org.apache.comet.{CometConf, ConfigEntry}
@@ -55,7 +56,9 @@ object CometExpand extends CometOperatorSerde[ExpandExec] {
withInfo(op, allProjExprs: _*)
None
}
-
}
+ override def createExec(nativeOp: Operator, op: ExpandExec): CometNativeExec = {
+ CometExpandExec(nativeOp, op, op.output, op.projections, op.child, SerializedPlan(None))
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometFilter.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometFilter.scala
index 96771b902..925f1fc18 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometFilter.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometFilter.scala
@@ -19,6 +19,7 @@
package org.apache.comet.serde.operator
+import org.apache.spark.sql.comet.{CometFilterExec, CometNativeExec, SerializedPlan}
import org.apache.spark.sql.execution.FilterExec
import org.apache.comet.{CometConf, ConfigEntry}
@@ -49,4 +50,7 @@ object CometFilter extends CometOperatorSerde[FilterExec] {
}
}
+ override def createExec(nativeOp: Operator, op: FilterExec): CometNativeExec = {
+ CometFilterExec(nativeOp, op, op.output, op.condition, op.child, SerializedPlan(None))
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometGlobalLimit.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometGlobalLimit.scala
index b2df3cf72..abb4d752c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometGlobalLimit.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometGlobalLimit.scala
@@ -19,6 +19,7 @@
package org.apache.comet.serde.operator
+import org.apache.spark.sql.comet.{CometGlobalLimitExec, CometNativeExec, SerializedPlan}
import org.apache.spark.sql.execution.GlobalLimitExec
import org.apache.comet.{CometConf, ConfigEntry}
@@ -45,6 +46,9 @@ object CometGlobalLimit extends CometOperatorSerde[GlobalLimitExec] {
withInfo(op, "No child operator")
None
}
+ }
+ override def createExec(nativeOp: Operator, op: GlobalLimitExec): CometNativeExec = {
+ CometGlobalLimitExec(nativeOp, op, op.limit, op.offset, op.child, SerializedPlan(None))
}
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometHashJoin.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometHashJoin.scala
index c58384e3a..ccb6434f6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometHashJoin.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometHashJoin.scala
@@ -24,6 +24,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.comet.{CometBroadcastHashJoinExec, CometHashJoinExec, CometNativeExec, SerializedPlan}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec}
import org.apache.comet.{CometConf, ConfigEntry}
@@ -109,6 +110,22 @@ object CometBroadcastHashJoin extends CometOperatorSerde[HashJoin] with CometHas
builder: Operator.Builder,
childOp: Operator*): Option[Operator] =
doConvert(join, builder, childOp: _*)
+
+ override def createExec(nativeOp: Operator, op: HashJoin): CometNativeExec = {
+ CometBroadcastHashJoinExec(
+ nativeOp,
+ op,
+ op.output,
+ op.outputOrdering,
+ op.leftKeys,
+ op.rightKeys,
+ op.joinType,
+ op.condition,
+ op.buildSide,
+ op.left,
+ op.right,
+ SerializedPlan(None))
+ }
}
object CometShuffleHashJoin extends CometOperatorSerde[HashJoin] with CometHashJoin {
@@ -121,4 +138,20 @@ object CometShuffleHashJoin extends CometOperatorSerde[HashJoin] with CometHashJ
builder: Operator.Builder,
childOp: Operator*): Option[Operator] =
doConvert(join, builder, childOp: _*)
+
+ override def createExec(nativeOp: Operator, op: HashJoin): CometNativeExec = {
+ CometHashJoinExec(
+ nativeOp,
+ op,
+ op.output,
+ op.outputOrdering,
+ op.leftKeys,
+ op.rightKeys,
+ op.joinType,
+ op.condition,
+ op.buildSide,
+ op.left,
+ op.right,
+ SerializedPlan(None))
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometLocalLimit.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometLocalLimit.scala
index 3e5fbdebb..7e14b6ce6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometLocalLimit.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometLocalLimit.scala
@@ -19,6 +19,7 @@
package org.apache.comet.serde.operator
+import org.apache.spark.sql.comet.{CometLocalLimitExec, CometNativeExec, SerializedPlan}
import org.apache.spark.sql.execution.LocalLimitExec
import org.apache.comet.{CometConf, ConfigEntry}
@@ -48,4 +49,8 @@ object CometLocalLimit extends CometOperatorSerde[LocalLimitExec] {
None
}
}
+
+ override def createExec(nativeOp: Operator, op: LocalLimitExec): CometNativeExec = {
+ CometLocalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None))
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometLocalTableScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometLocalTableScan.scala
index e3e8538cf..bfeb448dc 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometLocalTableScan.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometLocalTableScan.scala
@@ -21,6 +21,7 @@ package org.apache.comet.serde.operator
import scala.jdk.CollectionConverters._
+import org.apache.spark.sql.comet.{CometLocalTableScanExec, CometNativeExec, CometScanWrapper}
import org.apache.spark.sql.execution.LocalTableScanExec
import org.apache.comet.{CometConf, ConfigEntry}
@@ -45,4 +46,8 @@ object CometLocalTableScan extends CometOperatorSerde[LocalTableScanExec] {
.setArrowFfiSafe(false)
Some(builder.setScan(scanBuilder).build())
}
+
+ override def createExec(nativeOp: Operator, op: LocalTableScanExec): CometNativeExec = {
+ CometScanWrapper(nativeOp, CometLocalTableScanExec(op, op.rows, op.output))
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometNativeScan.scala
index 2bc8b5526..586ed8c0e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometNativeScan.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometNativeScan.scala
@@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
-import org.apache.spark.sql.comet.CometScanExec
+import org.apache.spark.sql.comet.{CometNativeExec, CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
import org.apache.spark.sql.internal.SQLConf
@@ -216,4 +216,7 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging {
nativeScanBuilder.addFilePartitions(partitionBuilder.build())
}
+ override def createExec(nativeOp: Operator, op: CometScanExec): CometNativeExec = {
+ CometNativeScanExec(nativeOp, op.wrapped, op.session)
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometProject.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometProject.scala
index 4ba02945d..92eeca58f 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometProject.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometProject.scala
@@ -21,6 +21,7 @@ package org.apache.comet.serde.operator
import scala.jdk.CollectionConverters._
+import org.apache.spark.sql.comet.{CometNativeExec, CometProjectExec, SerializedPlan}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.comet.{CometConf, ConfigEntry}
@@ -50,4 +51,8 @@ object CometProject extends CometOperatorSerde[ProjectExec] {
None
}
}
+
+ override def createExec(nativeOp: Operator, op: ProjectExec): CometNativeExec = {
+ CometProjectExec(nativeOp, op, op.output, op.projectList, op.child, SerializedPlan(None))
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSort.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSort.scala
index 39a1c5565..3929e4f11 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSort.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSort.scala
@@ -22,6 +22,7 @@ package org.apache.comet.serde.operator
import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Descending, NullsFirst, NullsLast, SortOrder}
+import org.apache.spark.sql.comet.{CometNativeExec, CometSortExec, SerializedPlan}
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType}
@@ -114,4 +115,14 @@ object CometSort extends CometOperatorSerde[SortExec] {
}
}
+ override def createExec(nativeOp: Operator, op: SortExec): CometNativeExec = {
+ CometSortExec(
+ nativeOp,
+ op,
+ op.output,
+ op.outputOrdering,
+ op.sortOrder,
+ op.child,
+ SerializedPlan(None))
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSortMergeJoin.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSortMergeJoin.scala
index a22230d7a..9296e2747 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSortMergeJoin.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSortMergeJoin.scala
@@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, ExpressionSet, SortOrder}
import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.comet.{CometNativeExec, CometSortMergeJoinExec, SerializedPlan}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.types._
@@ -128,7 +129,21 @@ object CometSortMergeJoin extends CometOperatorSerde[SortMergeJoinExec] {
withInfo(join, allExprs: _*)
None
}
+ }
+ override def createExec(nativeOp: Operator, op: SortMergeJoinExec): CometNativeExec = {
+ CometSortMergeJoinExec(
+ nativeOp,
+ op,
+ op.output,
+ op.outputOrdering,
+ op.leftKeys,
+ op.rightKeys,
+ op.joinType,
+ op.condition,
+ op.left,
+ op.right,
+ SerializedPlan(None))
}
/**
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometWindow.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometWindow.scala
index 0cb32bd22..4ceb11b4b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometWindow.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometWindow.scala
@@ -22,6 +22,7 @@ package org.apache.comet.serde.operator
import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, SortOrder, WindowExpression}
+import org.apache.spark.sql.comet.{CometNativeExec, CometWindowExec, SerializedPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.WindowExec
@@ -88,6 +89,18 @@ object CometWindow extends CometOperatorSerde[WindowExec] {
}
+ override def createExec(nativeOp: Operator, op: WindowExec): CometNativeExec = {
+ CometWindowExec(
+ nativeOp,
+ op,
+ op.output,
+ op.windowExpression,
+ op.partitionSpec,
+ op.orderSpec,
+ op.child,
+ SerializedPlan(None))
+ }
+
private def validatePartitionAndSortSpecsForWindowFunc(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index aadd1adfd..be9e80347 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -2108,7 +2108,9 @@ class CometExecSuite extends CometTestBase {
test("LocalTableScanExec spark fallback") {
withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "false") {
val df = Seq.range(0, 10).toDF("id")
- checkSparkAnswerAndFallbackReason(df, "LocalTableScan is not enabled")
+ checkSparkAnswerAndFallbackReason(
+ df,
+ "Native support for operator LocalTableScanExec is disabled")
}
}
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index 3c3264a81..c4c79e952 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -102,7 +102,7 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
val approvedFile = new File(dir, s"$filename.txt")
val actualFile = new File(tempDir, s"$name.actual.$filename.txt")
FileUtils.writeStringToFile(actualFile, plan, StandardCharsets.UTF_8)
- comparePlans("simplified", approvedFile, actualFile)
+ comparePlans(filename, approvedFile, actualFile)
}
private def comparePlans(planType: String, expectedFile: File, actualFile: File): Unit = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]