This is an automated email from the ASF dual-hosted git repository. agrove pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 168715106 chore: Refactor QueryPlanSerde to use idiomatic Scala and reduce verbosity (#1609) 168715106 is described below commit 168715106b4587cc509c4e022f72a70c50b9bf27 Author: Andy Grove <agr...@apache.org> AuthorDate: Mon Apr 7 11:09:20 2025 -0600 chore: Refactor QueryPlanSerde to use idiomatic Scala and reduce verbosity (#1609) * start refactor * more refactoring * address feedback --- .../apache/comet/CometSparkSessionExtensions.scala | 315 ++++++++------------- 1 file changed, 118 insertions(+), 197 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index daa1f19ea..183d04ee7 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -352,7 +352,7 @@ class CometSparkSessionExtensions */ // spotless:on private def transform(plan: SparkPlan): SparkPlan = { - def transform1(op: SparkPlan): Option[Operator] = { + def operator2Proto(op: SparkPlan): Option[Operator] = { if (op.children.forall(_.isInstanceOf[CometNativeExec])) { QueryPlanSerde.operator2Proto( op, @@ -366,6 +366,14 @@ class CometSparkSessionExtensions } } + /** + * Convert operator to proto and then apply a transformation to wrap the proto in a new + * plan. + */ + def newPlanWithProto(op: SparkPlan, fun: Operator => SparkPlan): SparkPlan = { + operator2Proto(op).map(fun).getOrElse(op) + } + plan.transformUp { // Fully native scan for V1 case scan: CometScanExec @@ -384,97 +392,54 @@ class CometSparkSessionExtensions CometScanWrapper(nativeOp.get, cometOp) case op: ProjectExec => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometProjectExec( - nativeOp, - op, - op.output, - op.projectList, - op.child, - SerializedPlan(None)) - case None => - op - } + newPlanWithProto( + op, + CometProjectExec(_, op, op.output, op.projectList, op.child, SerializedPlan(None))) case op: FilterExec => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometFilterExec( - nativeOp, - op, - op.output, - op.condition, - op.child, - SerializedPlan(None)) - case None => - op - } + newPlanWithProto( + op, + CometFilterExec(_, op, op.output, op.condition, op.child, SerializedPlan(None))) case op: SortExec => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometSortExec( - nativeOp, - op, - op.output, - op.outputOrdering, - op.sortOrder, - op.child, - SerializedPlan(None)) - case None => - op - } + newPlanWithProto( + op, + CometSortExec( + _, + op, + op.output, + op.outputOrdering, + op.sortOrder, + op.child, + SerializedPlan(None))) case op: LocalLimitExec => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometLocalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None)) - case None => - op - } + newPlanWithProto( + op, + CometLocalLimitExec(_, op, op.limit, op.child, SerializedPlan(None))) case op: GlobalLimitExec if op.offset == 0 => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometGlobalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None)) - case None => - op - } + newPlanWithProto( + op, + CometGlobalLimitExec(_, op, op.limit, op.child, SerializedPlan(None))) case op: CollectLimitExec if isCometNative(op.child) && CometConf.COMET_EXEC_COLLECT_LIMIT_ENABLED.get(conf) && isCometShuffleEnabled(conf) && op.offset == 0 => - QueryPlanSerde.operator2Proto(op) match { - case Some(nativeOp) => - val offset = op.offset + QueryPlanSerde + .operator2Proto(op) + .map { nativeOp => val cometOp = - CometCollectLimitExec(op, op.limit, offset, op.child) + CometCollectLimitExec(op, op.limit, op.offset, op.child) CometSinkPlaceHolder(nativeOp, op, cometOp) - case None => - op - } + } + .getOrElse(op) case op: ExpandExec => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometExpandExec( - nativeOp, - op, - op.output, - op.projections, - op.child, - SerializedPlan(None)) - case None => - op - } + newPlanWithProto( + op, + CometExpandExec(_, op, op.output, op.projections, op.child, SerializedPlan(None))) case op: BaseAggregateExec if op.isInstanceOf[HashAggregateExec] || @@ -503,9 +468,9 @@ class CometSparkSessionExtensions if (sparkFinalMode) { op } else { - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => + newPlanWithProto( + op, + nativeOp => { val modes = aggExprs.map(_.mode).distinct // The aggExprs could be empty. For example, if the aggregate functions only have // distinct aggregate functions or only have group by, the aggExprs is empty and @@ -523,34 +488,28 @@ class CometSparkSessionExtensions modes.headOption, child, SerializedPlan(None)) - case None => - op - } + }) } } case op: ShuffledHashJoinExec if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) && op.children.forall(isCometNative) => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometHashJoinExec( - nativeOp, - op, - op.output, - op.outputOrdering, - op.leftKeys, - op.rightKeys, - op.joinType, - op.condition, - op.buildSide, - op.left, - op.right, - SerializedPlan(None)) - case None => - op - } + newPlanWithProto( + op, + CometHashJoinExec( + _, + op, + op.output, + op.outputOrdering, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.buildSide, + op.left, + op.right, + SerializedPlan(None))) case op: ShuffledHashJoinExec if !CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) => withInfo(op, "ShuffleHashJoin is not enabled") @@ -566,47 +525,39 @@ class CometSparkSessionExtensions case op: BroadcastHashJoinExec if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) && op.children.forall(isCometNative) => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometBroadcastHashJoinExec( - nativeOp, - op, - op.output, - op.outputOrdering, - op.leftKeys, - op.rightKeys, - op.joinType, - op.condition, - op.buildSide, - op.left, - op.right, - SerializedPlan(None)) - case None => - op - } + newPlanWithProto( + op, + CometBroadcastHashJoinExec( + _, + op, + op.output, + op.outputOrdering, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.buildSide, + op.left, + op.right, + SerializedPlan(None))) case op: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) && op.children.forall(isCometNative) => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometSortMergeJoinExec( - nativeOp, - op, - op.output, - op.outputOrdering, - op.leftKeys, - op.rightKeys, - op.joinType, - op.condition, - op.left, - op.right, - SerializedPlan(None)) - case None => - op - } + newPlanWithProto( + op, + CometSortMergeJoinExec( + _, + op, + op.output, + op.outputOrdering, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.left, + op.right, + SerializedPlan(None))) case op: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) && @@ -631,13 +582,13 @@ class CometSparkSessionExtensions case c @ CoalesceExec(numPartitions, child) if CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf) && isCometNative(child) => - QueryPlanSerde.operator2Proto(c) match { - case Some(nativeOp) => + QueryPlanSerde + .operator2Proto(c) + .map { nativeOp => val cometOp = CometCoalesceExec(c, c.output, numPartitions, child) CometSinkPlaceHolder(nativeOp, c, cometOp) - case None => - c - } + } + .getOrElse(c) case c @ CoalesceExec(_, _) if !CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf) => withInfo(c, "Coalesce is not enabled") @@ -655,8 +606,9 @@ class CometSparkSessionExtensions .get(conf) && isCometShuffleEnabled(conf) && CometTakeOrderedAndProjectExec.isSupported(s) => - QueryPlanSerde.operator2Proto(s) match { - case Some(nativeOp) => + QueryPlanSerde + .operator2Proto(s) + .map { nativeOp => val cometOp = CometTakeOrderedAndProjectExec( s, @@ -666,9 +618,8 @@ class CometSparkSessionExtensions s.projectList, s.child) CometSinkPlaceHolder(nativeOp, s, cometOp) - case None => - s - } + } + .getOrElse(s) case s: TakeOrderedAndProjectExec => val info1 = createMessage( @@ -681,32 +632,26 @@ class CometSparkSessionExtensions s case w: WindowExec => - val newOp = transform1(w) - newOp match { - case Some(nativeOp) => - CometWindowExec( - nativeOp, - w, - w.output, - w.windowExpression, - w.partitionSpec, - w.orderSpec, - w.child, - SerializedPlan(None)) - case None => - w - } + newPlanWithProto( + w, + CometWindowExec( + _, + w, + w.output, + w.windowExpression, + w.partitionSpec, + w.orderSpec, + w.child, + SerializedPlan(None))) case u: UnionExec if CometConf.COMET_EXEC_UNION_ENABLED.get(conf) && u.children.forall(isCometNative) => - QueryPlanSerde.operator2Proto(u) match { - case Some(nativeOp) => + newPlanWithProto( + u, { val cometOp = CometUnionExec(u, u.output, u.children) - CometSinkPlaceHolder(nativeOp, u, cometOp) - case None => - u - } + CometSinkPlaceHolder(_, u, cometOp) + }) case u: UnionExec if !CometConf.COMET_EXEC_UNION_ENABLED.get(conf) => withInfo(u, "Union is not enabled") @@ -721,25 +666,13 @@ class CometSparkSessionExtensions // For AQE broadcast stage on a Comet broadcast exchange case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => - val newOp = transform1(s) - newOp match { - case Some(nativeOp) => - CometSinkPlaceHolder(nativeOp, s, s) - case None => - s - } + newPlanWithProto(s, CometSinkPlaceHolder(_, s, s)) case s @ BroadcastQueryStageExec( _, ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) => - val newOp = transform1(s) - newOp match { - case Some(nativeOp) => - CometSinkPlaceHolder(nativeOp, s, s) - case None => - s - } + newPlanWithProto(s, CometSinkPlaceHolder(_, s, s)) // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast // exchange. It is only used for Comet native execution. We only transform Spark broadcast @@ -795,13 +728,7 @@ class CometSparkSessionExtensions // For AQE shuffle stage on a Comet shuffle exchange case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => - val newOp = transform1(s) - newOp match { - case Some(nativeOp) => - CometSinkPlaceHolder(nativeOp, s, s) - case None => - s - } + newPlanWithProto(s, CometSinkPlaceHolder(_, s, s)) // For AQE shuffle stage on a reused Comet shuffle exchange // Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because @@ -810,13 +737,7 @@ class CometSparkSessionExtensions _, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => - val newOp = transform1(s) - newOp match { - case Some(nativeOp) => - CometSinkPlaceHolder(nativeOp, s, s) - case None => - s - } + newPlanWithProto(s, CometSinkPlaceHolder(_, s, s)) // Native shuffle for Comet operators case s: ShuffleExchangeExec => @@ -826,7 +747,7 @@ class CometSparkSessionExtensions val nativeShuffle: Option[SparkPlan] = if (nativePrecondition) { - val newOp = transform1(s) + val newOp = operator2Proto(s) newOp match { case Some(nativeOp) => // Switch to use Decimal128 regardless of precision, since Arrow native execution --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org