This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 168715106 chore: Refactor QueryPlanSerde to use idiomatic Scala and reduce verbosity (#1609)
168715106 is described below

commit 168715106b4587cc509c4e022f72a70c50b9bf27
Author: Andy Grove <agr...@apache.org>
AuthorDate: Mon Apr 7 11:09:20 2025 -0600

    chore: Refactor QueryPlanSerde to use idiomatic Scala and reduce verbosity (#1609)
    
    * start refactor
    
    * more refactoring
    
    * address feedback
---
 .../apache/comet/CometSparkSessionExtensions.scala | 315 ++++++++-------------
 1 file changed, 118 insertions(+), 197 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index daa1f19ea..183d04ee7 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -352,7 +352,7 @@ class CometSparkSessionExtensions
      */
     // spotless:on
     private def transform(plan: SparkPlan): SparkPlan = {
-      def transform1(op: SparkPlan): Option[Operator] = {
+      def operator2Proto(op: SparkPlan): Option[Operator] = {
         if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
           QueryPlanSerde.operator2Proto(
             op,
@@ -366,6 +366,14 @@ class CometSparkSessionExtensions
         }
       }
 
+      /**
+       * Convert operator to proto and then apply a transformation to wrap the proto in a new
+       * plan.
+       */
+      def newPlanWithProto(op: SparkPlan, fun: Operator => SparkPlan): 
SparkPlan = {
+        operator2Proto(op).map(fun).getOrElse(op)
+      }
+
       plan.transformUp {
         // Fully native scan for V1
         case scan: CometScanExec
@@ -384,97 +392,54 @@ class CometSparkSessionExtensions
           CometScanWrapper(nativeOp.get, cometOp)
 
         case op: ProjectExec =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              CometProjectExec(
-                nativeOp,
-                op,
-                op.output,
-                op.projectList,
-                op.child,
-                SerializedPlan(None))
-            case None =>
-              op
-          }
+          newPlanWithProto(
+            op,
+            CometProjectExec(_, op, op.output, op.projectList, op.child, 
SerializedPlan(None)))
 
         case op: FilterExec =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              CometFilterExec(
-                nativeOp,
-                op,
-                op.output,
-                op.condition,
-                op.child,
-                SerializedPlan(None))
-            case None =>
-              op
-          }
+          newPlanWithProto(
+            op,
+            CometFilterExec(_, op, op.output, op.condition, op.child, 
SerializedPlan(None)))
 
         case op: SortExec =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              CometSortExec(
-                nativeOp,
-                op,
-                op.output,
-                op.outputOrdering,
-                op.sortOrder,
-                op.child,
-                SerializedPlan(None))
-            case None =>
-              op
-          }
+          newPlanWithProto(
+            op,
+            CometSortExec(
+              _,
+              op,
+              op.output,
+              op.outputOrdering,
+              op.sortOrder,
+              op.child,
+              SerializedPlan(None)))
 
         case op: LocalLimitExec =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              CometLocalLimitExec(nativeOp, op, op.limit, op.child, 
SerializedPlan(None))
-            case None =>
-              op
-          }
+          newPlanWithProto(
+            op,
+            CometLocalLimitExec(_, op, op.limit, op.child, 
SerializedPlan(None)))
 
         case op: GlobalLimitExec if op.offset == 0 =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              CometGlobalLimitExec(nativeOp, op, op.limit, op.child, 
SerializedPlan(None))
-            case None =>
-              op
-          }
+          newPlanWithProto(
+            op,
+            CometGlobalLimitExec(_, op, op.limit, op.child, 
SerializedPlan(None)))
 
         case op: CollectLimitExec
             if isCometNative(op.child) && 
CometConf.COMET_EXEC_COLLECT_LIMIT_ENABLED.get(conf)
               && isCometShuffleEnabled(conf)
               && op.offset == 0 =>
-          QueryPlanSerde.operator2Proto(op) match {
-            case Some(nativeOp) =>
-              val offset = op.offset
+          QueryPlanSerde
+            .operator2Proto(op)
+            .map { nativeOp =>
               val cometOp =
-                CometCollectLimitExec(op, op.limit, offset, op.child)
+                CometCollectLimitExec(op, op.limit, op.offset, op.child)
               CometSinkPlaceHolder(nativeOp, op, cometOp)
-            case None =>
-              op
-          }
+            }
+            .getOrElse(op)
 
         case op: ExpandExec =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              CometExpandExec(
-                nativeOp,
-                op,
-                op.output,
-                op.projections,
-                op.child,
-                SerializedPlan(None))
-            case None =>
-              op
-          }
+          newPlanWithProto(
+            op,
+            CometExpandExec(_, op, op.output, op.projections, op.child, 
SerializedPlan(None)))
 
         case op: BaseAggregateExec
             if op.isInstanceOf[HashAggregateExec] ||
@@ -503,9 +468,9 @@ class CometSparkSessionExtensions
             if (sparkFinalMode) {
               op
             } else {
-              val newOp = transform1(op)
-              newOp match {
-                case Some(nativeOp) =>
+              newPlanWithProto(
+                op,
+                nativeOp => {
                   val modes = aggExprs.map(_.mode).distinct
                   // The aggExprs could be empty. For example, if the aggregate functions only have
                   // distinct aggregate functions or only have group by, the aggExprs is empty and
@@ -523,34 +488,28 @@ class CometSparkSessionExtensions
                     modes.headOption,
                     child,
                     SerializedPlan(None))
-                case None =>
-                  op
-              }
+                })
             }
           }
 
         case op: ShuffledHashJoinExec
             if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
               op.children.forall(isCometNative) =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              CometHashJoinExec(
-                nativeOp,
-                op,
-                op.output,
-                op.outputOrdering,
-                op.leftKeys,
-                op.rightKeys,
-                op.joinType,
-                op.condition,
-                op.buildSide,
-                op.left,
-                op.right,
-                SerializedPlan(None))
-            case None =>
-              op
-          }
+          newPlanWithProto(
+            op,
+            CometHashJoinExec(
+              _,
+              op,
+              op.output,
+              op.outputOrdering,
+              op.leftKeys,
+              op.rightKeys,
+              op.joinType,
+              op.condition,
+              op.buildSide,
+              op.left,
+              op.right,
+              SerializedPlan(None)))
 
         case op: ShuffledHashJoinExec if 
!CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) =>
           withInfo(op, "ShuffleHashJoin is not enabled")
@@ -566,47 +525,39 @@ class CometSparkSessionExtensions
         case op: BroadcastHashJoinExec
             if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
               op.children.forall(isCometNative) =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              CometBroadcastHashJoinExec(
-                nativeOp,
-                op,
-                op.output,
-                op.outputOrdering,
-                op.leftKeys,
-                op.rightKeys,
-                op.joinType,
-                op.condition,
-                op.buildSide,
-                op.left,
-                op.right,
-                SerializedPlan(None))
-            case None =>
-              op
-          }
+          newPlanWithProto(
+            op,
+            CometBroadcastHashJoinExec(
+              _,
+              op,
+              op.output,
+              op.outputOrdering,
+              op.leftKeys,
+              op.rightKeys,
+              op.joinType,
+              op.condition,
+              op.buildSide,
+              op.left,
+              op.right,
+              SerializedPlan(None)))
 
         case op: SortMergeJoinExec
             if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
               op.children.forall(isCometNative) =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              CometSortMergeJoinExec(
-                nativeOp,
-                op,
-                op.output,
-                op.outputOrdering,
-                op.leftKeys,
-                op.rightKeys,
-                op.joinType,
-                op.condition,
-                op.left,
-                op.right,
-                SerializedPlan(None))
-            case None =>
-              op
-          }
+          newPlanWithProto(
+            op,
+            CometSortMergeJoinExec(
+              _,
+              op,
+              op.output,
+              op.outputOrdering,
+              op.leftKeys,
+              op.rightKeys,
+              op.joinType,
+              op.condition,
+              op.left,
+              op.right,
+              SerializedPlan(None)))
 
         case op: SortMergeJoinExec
             if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
@@ -631,13 +582,13 @@ class CometSparkSessionExtensions
         case c @ CoalesceExec(numPartitions, child)
             if CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf)
               && isCometNative(child) =>
-          QueryPlanSerde.operator2Proto(c) match {
-            case Some(nativeOp) =>
+          QueryPlanSerde
+            .operator2Proto(c)
+            .map { nativeOp =>
               val cometOp = CometCoalesceExec(c, c.output, numPartitions, 
child)
               CometSinkPlaceHolder(nativeOp, c, cometOp)
-            case None =>
-              c
-          }
+            }
+            .getOrElse(c)
 
         case c @ CoalesceExec(_, _) if 
!CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf) =>
           withInfo(c, "Coalesce is not enabled")
@@ -655,8 +606,9 @@ class CometSparkSessionExtensions
               .get(conf)
               && isCometShuffleEnabled(conf) &&
               CometTakeOrderedAndProjectExec.isSupported(s) =>
-          QueryPlanSerde.operator2Proto(s) match {
-            case Some(nativeOp) =>
+          QueryPlanSerde
+            .operator2Proto(s)
+            .map { nativeOp =>
               val cometOp =
                 CometTakeOrderedAndProjectExec(
                   s,
@@ -666,9 +618,8 @@ class CometSparkSessionExtensions
                   s.projectList,
                   s.child)
               CometSinkPlaceHolder(nativeOp, s, cometOp)
-            case None =>
-              s
-          }
+            }
+            .getOrElse(s)
 
         case s: TakeOrderedAndProjectExec =>
           val info1 = createMessage(
@@ -681,32 +632,26 @@ class CometSparkSessionExtensions
           s
 
         case w: WindowExec =>
-          val newOp = transform1(w)
-          newOp match {
-            case Some(nativeOp) =>
-              CometWindowExec(
-                nativeOp,
-                w,
-                w.output,
-                w.windowExpression,
-                w.partitionSpec,
-                w.orderSpec,
-                w.child,
-                SerializedPlan(None))
-            case None =>
-              w
-          }
+          newPlanWithProto(
+            w,
+            CometWindowExec(
+              _,
+              w,
+              w.output,
+              w.windowExpression,
+              w.partitionSpec,
+              w.orderSpec,
+              w.child,
+              SerializedPlan(None)))
 
         case u: UnionExec
             if CometConf.COMET_EXEC_UNION_ENABLED.get(conf) &&
               u.children.forall(isCometNative) =>
-          QueryPlanSerde.operator2Proto(u) match {
-            case Some(nativeOp) =>
+          newPlanWithProto(
+            u, {
               val cometOp = CometUnionExec(u, u.output, u.children)
-              CometSinkPlaceHolder(nativeOp, u, cometOp)
-            case None =>
-              u
-          }
+              CometSinkPlaceHolder(_, u, cometOp)
+            })
 
         case u: UnionExec if !CometConf.COMET_EXEC_UNION_ENABLED.get(conf) =>
           withInfo(u, "Union is not enabled")
@@ -721,25 +666,13 @@ class CometSparkSessionExtensions
 
         // For AQE broadcast stage on a Comet broadcast exchange
         case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) 
=>
-          val newOp = transform1(s)
-          newOp match {
-            case Some(nativeOp) =>
-              CometSinkPlaceHolder(nativeOp, s, s)
-            case None =>
-              s
-          }
+          newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
 
         case s @ BroadcastQueryStageExec(
               _,
               ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
               _) =>
-          val newOp = transform1(s)
-          newOp match {
-            case Some(nativeOp) =>
-              CometSinkPlaceHolder(nativeOp, s, s)
-            case None =>
-              s
-          }
+          newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
 
         // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
         // exchange. It is only used for Comet native execution. We only transform Spark broadcast
@@ -795,13 +728,7 @@ class CometSparkSessionExtensions
 
         // For AQE shuffle stage on a Comet shuffle exchange
         case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
-          val newOp = transform1(s)
-          newOp match {
-            case Some(nativeOp) =>
-              CometSinkPlaceHolder(nativeOp, s, s)
-            case None =>
-              s
-          }
+          newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
 
         // For AQE shuffle stage on a reused Comet shuffle exchange
         // Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because
@@ -810,13 +737,7 @@ class CometSparkSessionExtensions
               _,
               ReusedExchangeExec(_, _: CometShuffleExchangeExec),
               _) =>
-          val newOp = transform1(s)
-          newOp match {
-            case Some(nativeOp) =>
-              CometSinkPlaceHolder(nativeOp, s, s)
-            case None =>
-              s
-          }
+          newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
 
         // Native shuffle for Comet operators
         case s: ShuffleExchangeExec =>
@@ -826,7 +747,7 @@ class CometSparkSessionExtensions
 
           val nativeShuffle: Option[SparkPlan] =
             if (nativePrecondition) {
-              val newOp = transform1(s)
+              val newOp = operator2Proto(s)
               newOp match {
                 case Some(nativeOp) =>
                   // Switch to use Decimal128 regardless of precision, since Arrow native execution


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to