This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new d0ca41471 minor: Refactor to move some shuffle-related logic from `QueryPlanSerde` to `CometExecRule` (#2015)
d0ca41471 is described below

commit d0ca41471bfda53cf1856b6e3fbd5101b424037e
Author: Andy Grove <agr...@apache.org>
AuthorDate: Fri Jul 11 13:30:06 2025 -0600

    minor: Refactor to move some shuffle-related logic from `QueryPlanSerde` to `CometExecRule` (#2015)
---
 .../org/apache/comet/rules/CometExecRule.scala     | 172 ++++++++++++++++++---
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 135 ----------------
 2 files changed, 152 insertions(+), 155 deletions(-)
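
A hedged before/after sketch of the change (not part of the patch itself; the
guard shapes are copied verbatim from the CometExecRule hunks below):

    // Before: the rule's guards delegated the shuffle support checks to QueryPlanSerde.
    case s: ShuffleExchangeExec
        if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
          QueryPlanSerde.nativeShuffleSupported(s)._1 => ...

    // After: the same checks are private helpers on the rule itself.
    case s: ShuffleExchangeExec
        if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
          nativeShuffleSupported(s)._1 => ...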

diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 2383dd844..a1464a2e0 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
@@ -34,13 +35,15 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.types.{DoubleType, FloatType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
 
 import org.apache.comet.{CometConf, ExtendedExplainInfo}
 import org.apache.comet.CometConf.COMET_ANSI_MODE_ENABLED
 import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
+import org.apache.comet.serde.QueryPlanSerde.emitWarning
 
 /**
  * Spark physical optimizer rule for replacing Spark operators with Comet operators.
@@ -53,7 +56,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     plan.transformUp {
       case s: ShuffleExchangeExec
           if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
-            QueryPlanSerde.nativeShuffleSupported(s)._1 =>
+            nativeShuffleSupported(s)._1 =>
         logInfo("Comet extension enabled for Native Shuffle")
 
         // Switch to use Decimal128 regardless of precision, since Arrow native execution
@@ -65,7 +68,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       // (if configured)
       case s: ShuffleExchangeExec
           if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(conf) &&
-            QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+            columnarShuffleSupported(s)._1 &&
             !isShuffleOperator(s.child) =>
         logInfo("Comet extension enabled for JVM Columnar Shuffle")
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
@@ -490,7 +493,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case s: ShuffleExchangeExec =>
         val nativePrecondition = isCometShuffleEnabled(conf) &&
           isCometNativeShuffleMode(conf) &&
-          QueryPlanSerde.nativeShuffleSupported(s)._1
+          nativeShuffleSupported(s)._1
 
         val nativeShuffle: Option[SparkPlan] =
           if (nativePrecondition) {
@@ -517,7 +520,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
           // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
           // convert it to CometColumnarShuffle,
           if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
-            QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+            columnarShuffleSupported(s)._1 &&
             !isShuffleOperator(s.child)) {
 
             val newOp = QueryPlanSerde.operator2Proto(s)
@@ -547,18 +550,12 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
           val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
           val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
           val msg2 = createMessage(
-            isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
-              .nativeShuffleSupported(s)
-              ._1,
+            isShuffleEnabled && !columnarShuffleEnabled && !nativeShuffleSupported(s)._1,
             "Native shuffle: " +
-              s"${QueryPlanSerde.nativeShuffleSupported(s)._2}")
-          val typeInfo = QueryPlanSerde
-            .columnarShuffleSupported(s)
-            ._2
+              s"${nativeShuffleSupported(s)._2}")
+          val typeInfo = columnarShuffleSupported(s)._2
           val msg3 = createMessage(
-            isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
-              .columnarShuffleSupported(s)
-              ._1,
+            isShuffleEnabled && columnarShuffleEnabled && !columnarShuffleSupported(s)._1,
             "JVM shuffle: " +
               s"$typeInfo")
           withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
@@ -578,7 +575,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  def normalizePlan(plan: SparkPlan): SparkPlan = {
+  private def normalizePlan(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
       case p: ProjectExec =>
         val newProjectList = p.projectList.map(normalize(_).asInstanceOf[NamedExpression])
@@ -595,7 +592,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   // because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`). But the
   // comparison functions in arrow-rs do not normalize NaN and zero. So we need to normalize NaN
   // and zero for comparison operators in Comet.
-  def normalize(expr: Expression): Expression = {
+  private def normalize(expr: Expression): Expression = {
     expr.transformUp {
       case EqualTo(left, right) =>
         EqualTo(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
@@ -616,7 +613,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  def normalizeNaNAndZero(expr: Expression): Expression = {
+  private def normalizeNaNAndZero(expr: Expression): Expression = {
     expr match {
       case _: KnownFloatingPointNormalized => expr
       case FloatLiteral(f) if !f.equals(-0.0f) => expr
@@ -755,7 +752,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
    * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
    * partial mode, it will return None.
    */
-  def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+  private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
     plan.collectFirst {
       case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
         Some(agg)
@@ -770,7 +767,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   /**
    * Returns true if a given spark plan is Comet shuffle operator.
    */
-  def isShuffleOperator(op: SparkPlan): Boolean = {
+  private def isShuffleOperator(op: SparkPlan): Boolean = {
     op match {
       case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
       case _: CometShuffleExchangeExec => true
@@ -778,4 +775,139 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case _ => false
     }
   }
+
+  /**
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
+   */
+  private def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+
+    /**
+     * Determine which data types are supported as hash-partition keys in native shuffle.
+     *
+     * Hash Partition Key determines how data should be collocated for operations like
+     * `groupByKey`, `reduceByKey` or `join`.
+     */
+    def supportedHashPartitionKeyDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
+        true
+      case _ =>
+        false
+    }
+
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    val conf = SQLConf.get
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // native shuffle currently does not support complex types as partition keys
+        // due to lack of hashing support for those types
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+            CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(ordering, _) =>
+        val supported = ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+          inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+          CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $ordering"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: 
${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
+
+  /**
+   * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
+   * which supports struct/array.
+   */
+  private def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // columnar shuffle supports the same data types (including complex types) both for
+        // partition keys and for other columns
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RoundRobinPartitioning(_) =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(orderings, _) =>
+        val supported =
+          orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $orderings"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: 
${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
+
+  /**
+   * Determine which data types are supported in a shuffle.
+   */
+  private def supportedShuffleDataType(dt: DataType): Boolean = dt match {
+    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+        _: TimestampNTZType | _: DecimalType | _: DateType =>
+      true
+    case StructType(fields) =>
+      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
+      // Java Arrow stream reader cannot work on duplicate field name
+      fields.map(f => f.name).distinct.length == fields.length
+    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
+    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
+    case ArrayType(elementType, _) =>
+      supportedShuffleDataType(elementType)
+    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
+    case MapType(_, MapType(_, _, _), _) => false
+    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
+    case MapType(_, StructType(_), _) => false
+    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
+    case MapType(_, ArrayType(_, _), _) => false
+    case MapType(keyType, valueType, _) =>
+      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
+    case _ =>
+      false
+  }
+
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 077faeb41..970329b28 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
 import org.apache.spark.sql.comet._
@@ -2725,140 +2724,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
     }
   }
 
-  /**
-   * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
-   * which supports struct/array.
-   */
-  def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
-    val inputs = s.child.output
-    val partitioning = s.outputPartitioning
-    var msg = ""
-    val supported = partitioning match {
-      case HashPartitioning(expressions, _) =>
-        // columnar shuffle supports the same data types (including complex types) both for
-        // partition keys and for other columns
-        val supported =
-          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
-            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-        if (!supported) {
-          msg = s"unsupported Spark partitioning expressions: $expressions"
-        }
-        supported
-      case SinglePartition =>
-        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-      case RoundRobinPartitioning(_) =>
-        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-      case RangePartitioning(orderings, _) =>
-        val supported =
-          orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
-            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-        if (!supported) {
-          msg = s"unsupported Spark partitioning expressions: $orderings"
-        }
-        supported
-      case _ =>
-        msg = s"unsupported Spark partitioning: 
${partitioning.getClass.getName}"
-        false
-    }
-
-    if (!supported) {
-      emitWarning(msg)
-      (false, msg)
-    } else {
-      (true, null)
-    }
-  }
-
-  /**
-   * Whether the given Spark partitioning is supported by Comet native shuffle.
-   */
-  def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
-
-    /**
-     * Determine which data types are supported as hash-partition keys in native shuffle.
-     *
-     * Hash Partition Key determines how data should be collocated for operations like
-     * `groupByKey`, `reduceByKey` or `join`.
-     */
-    def supportedHashPartitionKeyDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case _ =>
-        false
-    }
-
-    val inputs = s.child.output
-    val partitioning = s.outputPartitioning
-    val conf = SQLConf.get
-    var msg = ""
-    val supported = partitioning match {
-      case HashPartitioning(expressions, _) =>
-        // native shuffle currently does not support complex types as partition keys
-        // due to lack of hashing support for those types
-        val supported =
-          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
-            inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
-            CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
-        if (!supported) {
-          msg = s"unsupported Spark partitioning: $expressions"
-        }
-        supported
-      case SinglePartition =>
-        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-      case RangePartitioning(ordering, _) =>
-        val supported = ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-          inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
-          CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
-        if (!supported) {
-          msg = s"unsupported Spark partitioning: $ordering"
-        }
-        supported
-      case _ =>
-        msg = s"unsupported Spark partitioning: 
${partitioning.getClass.getName}"
-        false
-    }
-
-    if (!supported) {
-      emitWarning(msg)
-      (false, msg)
-    } else {
-      (true, null)
-    }
-  }
-
-  /**
-   * Determine which data types are supported in a shuffle.
-   */
-  def supportedShuffleDataType(dt: DataType): Boolean = dt match {
-    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-        _: TimestampNTZType | _: DecimalType | _: DateType =>
-      true
-    case StructType(fields) =>
-      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
-      // Java Arrow stream reader cannot work on duplicate field name
-      fields.map(f => f.name).distinct.length == fields.length
-    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
-    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
-    case ArrayType(elementType, _) =>
-      supportedShuffleDataType(elementType)
-    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
-    case MapType(_, MapType(_, _, _), _) => false
-    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
-    case MapType(_, StructType(_), _) => false
-    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
-    case MapType(_, ArrayType(_, _), _) => false
-    case MapType(keyType, valueType, _) =>
-      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
-    case _ =>
-      false
-  }
-
   // Utility method. Adds explain info if the result of calling exprToProto is None
   def optExprWithInfo(
       optExpr: Option[Expr],
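
For readers who want to poke at the moved logic in isolation: a minimal,
self-contained sketch of the recursive type check (hypothetical demo code, not
part of the commit; it mirrors the private supportedShuffleDataType helper
added to CometExecRule above, abridged to the primitive/struct/array arms):

    import org.apache.spark.sql.types._

    object ShuffleTypeCheckSketch {
      def supported(dt: DataType): Boolean = dt match {
        case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
            _: LongType | _: FloatType | _: DoubleType | _: StringType |
            _: BinaryType | _: TimestampType | _: TimestampNTZType |
            _: DecimalType | _: DateType =>
          true
        case StructType(fields) =>
          // the Java Arrow stream reader rejects duplicate field names
          fields.forall(f => supported(f.dataType)) &&
            fields.map(_.name).distinct.length == fields.length
        case ArrayType(ArrayType(_, _), _) => false // nested array unsupported
        case ArrayType(elementType, _) => supported(elementType)
        case _ => false // map arms elided in this sketch
      }

      def main(args: Array[String]): Unit = {
        println(supported(ArrayType(LongType)))               // true
        println(supported(ArrayType(ArrayType(IntegerType)))) // false
        val dup = StructType(
          Seq(StructField("a", IntegerType), StructField("a", StringType)))
        println(supported(dup))                               // false: duplicate names
      }
    }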


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org
